liuhahi committed
Commit 9cb8fca · 1 parent: 9671d2b

clear cache

.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+
+# ignore all pyc files
+*.pyc filter=lfs diff=lfs merge=lfs -text
InstantID/__pycache__/CrossAttentionPatch.cpython-310.pyc DELETED
Binary file (4.86 kB)
 
InstantID/__pycache__/resampler.cpython-310.pyc DELETED
Binary file (3.16 kB)
 
InstantID/__pycache__/utils.cpython-310.pyc DELETED
Binary file (940 Bytes)
 
InstantID/comfy/__pycache__/utils.cpython-310.pyc DELETED
Binary file (879 Bytes)
 
InstantID/comfy/ldm/modules/__pycache__/attention.cpython-310.pyc DELETED
Binary file (1.93 kB)
 
InstantID/comfy/model_management.py DELETED
@@ -1,1158 +0,0 @@
1
- """
2
- This file is part of ComfyUI.
3
- Copyright (C) 2024 Comfy
4
-
5
- This program is free software: you can redistribute it and/or modify
6
- it under the terms of the GNU General Public License as published by
7
- the Free Software Foundation, either version 3 of the License, or
8
- (at your option) any later version.
9
-
10
- This program is distributed in the hope that it will be useful,
11
- but WITHOUT ANY WARRANTY; without even the implied warranty of
12
- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
- GNU General Public License for more details.
14
-
15
- You should have received a copy of the GNU General Public License
16
- along with this program. If not, see <https://www.gnu.org/licenses/>.
17
- """
18
-
19
- import psutil
20
- import logging
21
- from enum import Enum
22
- from comfy.cli_args import args
23
- import torch
24
- import sys
25
- import platform
26
- import weakref
27
- import gc
28
-
29
- class VRAMState(Enum):
30
- DISABLED = 0 #No vram present: no need to move models to vram
31
- NO_VRAM = 1 #Very low vram: enable all the options to save vram
32
- LOW_VRAM = 2
33
- NORMAL_VRAM = 3
34
- HIGH_VRAM = 4
35
- SHARED = 5 #No dedicated vram: memory shared between CPU and GPU but models still need to be moved between both.
36
-
37
- class CPUState(Enum):
38
- GPU = 0
39
- CPU = 1
40
- MPS = 2
41
-
42
- # Determine VRAM State
43
- vram_state = VRAMState.NORMAL_VRAM
44
- set_vram_to = VRAMState.NORMAL_VRAM
45
- cpu_state = CPUState.GPU
46
-
47
- total_vram = 0
48
-
49
- xpu_available = False
50
- torch_version = ""
51
- try:
52
- torch_version = torch.version.__version__
53
- xpu_available = (int(torch_version[0]) < 2 or (int(torch_version[0]) == 2 and int(torch_version[2]) <= 4)) and torch.xpu.is_available()
54
- except:
55
- pass
56
-
57
- lowvram_available = True
58
- if args.deterministic:
59
- logging.info("Using deterministic algorithms for pytorch")
60
- torch.use_deterministic_algorithms(True, warn_only=True)
61
-
62
- directml_enabled = False
63
- if args.directml is not None:
64
- import torch_directml
65
- directml_enabled = True
66
- device_index = args.directml
67
- if device_index < 0:
68
- directml_device = torch_directml.device()
69
- else:
70
- directml_device = torch_directml.device(device_index)
71
- logging.info("Using directml with device: {}".format(torch_directml.device_name(device_index)))
72
- # torch_directml.disable_tiled_resources(True)
73
- lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.
74
-
75
- try:
76
- import intel_extension_for_pytorch as ipex
77
- _ = torch.xpu.device_count()
78
- xpu_available = xpu_available or torch.xpu.is_available()
79
- except:
80
- xpu_available = xpu_available or (hasattr(torch, "xpu") and torch.xpu.is_available())
81
-
82
- try:
83
- if torch.backends.mps.is_available():
84
- cpu_state = CPUState.MPS
85
- import torch.mps
86
- except:
87
- pass
88
-
89
- try:
90
- import torch_npu # noqa: F401
91
- _ = torch.npu.device_count()
92
- npu_available = torch.npu.is_available()
93
- except:
94
- npu_available = False
95
-
96
- if args.cpu:
97
- cpu_state = CPUState.CPU
98
-
99
- def is_intel_xpu():
100
- global cpu_state
101
- global xpu_available
102
- if cpu_state == CPUState.GPU:
103
- if xpu_available:
104
- return True
105
- return False
106
-
107
- def is_ascend_npu():
108
- global npu_available
109
- if npu_available:
110
- return True
111
- return False
112
-
113
- def get_torch_device():
114
- global directml_enabled
115
- global cpu_state
116
- if directml_enabled:
117
- global directml_device
118
- return directml_device
119
- if cpu_state == CPUState.MPS:
120
- return torch.device("mps")
121
- if cpu_state == CPUState.CPU:
122
- return torch.device("cpu")
123
- else:
124
- if is_intel_xpu():
125
- return torch.device("xpu", torch.xpu.current_device())
126
- elif is_ascend_npu():
127
- return torch.device("npu", torch.npu.current_device())
128
- else:
129
- return torch.device(torch.cuda.current_device())
130
-
131
- def get_total_memory(dev=None, torch_total_too=False):
132
- global directml_enabled
133
- if dev is None:
134
- dev = get_torch_device()
135
-
136
- if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
137
- mem_total = psutil.virtual_memory().total
138
- mem_total_torch = mem_total
139
- else:
140
- if directml_enabled:
141
- mem_total = 1024 * 1024 * 1024 #TODO
142
- mem_total_torch = mem_total
143
- elif is_intel_xpu():
144
- stats = torch.xpu.memory_stats(dev)
145
- mem_reserved = stats['reserved_bytes.all.current']
146
- mem_total_torch = mem_reserved
147
- mem_total = torch.xpu.get_device_properties(dev).total_memory
148
- elif is_ascend_npu():
149
- stats = torch.npu.memory_stats(dev)
150
- mem_reserved = stats['reserved_bytes.all.current']
151
- _, mem_total_npu = torch.npu.mem_get_info(dev)
152
- mem_total_torch = mem_reserved
153
- mem_total = mem_total_npu
154
- else:
155
- stats = torch.cuda.memory_stats(dev)
156
- mem_reserved = stats['reserved_bytes.all.current']
157
- _, mem_total_cuda = torch.cuda.mem_get_info(dev)
158
- mem_total_torch = mem_reserved
159
- mem_total = mem_total_cuda
160
-
161
- if torch_total_too:
162
- return (mem_total, mem_total_torch)
163
- else:
164
- return mem_total
165
-
166
- total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
167
- total_ram = psutil.virtual_memory().total / (1024 * 1024)
168
- logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
169
-
170
- try:
171
- logging.info("pytorch version: {}".format(torch_version))
172
- except:
173
- pass
174
-
175
- try:
176
- OOM_EXCEPTION = torch.cuda.OutOfMemoryError
177
- except:
178
- OOM_EXCEPTION = Exception
179
-
180
- XFORMERS_VERSION = ""
181
- XFORMERS_ENABLED_VAE = True
182
- if args.disable_xformers:
183
- XFORMERS_IS_AVAILABLE = False
184
- else:
185
- try:
186
- import xformers
187
- import xformers.ops
188
- XFORMERS_IS_AVAILABLE = True
189
- try:
190
- XFORMERS_IS_AVAILABLE = xformers._has_cpp_library
191
- except:
192
- pass
193
- try:
194
- XFORMERS_VERSION = xformers.version.__version__
195
- logging.info("xformers version: {}".format(XFORMERS_VERSION))
196
- if XFORMERS_VERSION.startswith("0.0.18"):
197
- logging.warning("\nWARNING: This version of xformers has a major bug where you will get black images when generating high resolution images.")
198
- logging.warning("Please downgrade or upgrade xformers to a different version.\n")
199
- XFORMERS_ENABLED_VAE = False
200
- except:
201
- pass
202
- except:
203
- XFORMERS_IS_AVAILABLE = False
204
-
205
- def is_nvidia():
206
- global cpu_state
207
- if cpu_state == CPUState.GPU:
208
- if torch.version.cuda:
209
- return True
210
- return False
211
-
212
- def is_amd():
213
- global cpu_state
214
- if cpu_state == CPUState.GPU:
215
- if torch.version.hip:
216
- return True
217
- return False
218
-
219
- MIN_WEIGHT_MEMORY_RATIO = 0.4
220
- if is_nvidia():
221
- MIN_WEIGHT_MEMORY_RATIO = 0.2
222
-
223
- ENABLE_PYTORCH_ATTENTION = False
224
- if args.use_pytorch_cross_attention:
225
- ENABLE_PYTORCH_ATTENTION = True
226
- XFORMERS_IS_AVAILABLE = False
227
-
228
- try:
229
- if is_nvidia():
230
- if int(torch_version[0]) >= 2:
231
- if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
232
- ENABLE_PYTORCH_ATTENTION = True
233
- if is_intel_xpu() or is_ascend_npu():
234
- if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
235
- ENABLE_PYTORCH_ATTENTION = True
236
- except:
237
- pass
238
-
239
- if ENABLE_PYTORCH_ATTENTION:
240
- torch.backends.cuda.enable_math_sdp(True)
241
- torch.backends.cuda.enable_flash_sdp(True)
242
- torch.backends.cuda.enable_mem_efficient_sdp(True)
243
-
244
- try:
245
- if int(torch_version[0]) == 2 and int(torch_version[2]) >= 5:
246
- torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)
247
- except:
248
- logging.warning("Warning, could not set allow_fp16_bf16_reduction_math_sdp")
249
-
250
- if args.lowvram:
251
- set_vram_to = VRAMState.LOW_VRAM
252
- lowvram_available = True
253
- elif args.novram:
254
- set_vram_to = VRAMState.NO_VRAM
255
- elif args.highvram or args.gpu_only:
256
- vram_state = VRAMState.HIGH_VRAM
257
-
258
- FORCE_FP32 = False
259
- FORCE_FP16 = False
260
- if args.force_fp32:
261
- logging.info("Forcing FP32, if this improves things please report it.")
262
- FORCE_FP32 = True
263
-
264
- if args.force_fp16:
265
- logging.info("Forcing FP16.")
266
- FORCE_FP16 = True
267
-
268
- if lowvram_available:
269
- if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
270
- vram_state = set_vram_to
271
-
272
-
273
- if cpu_state != CPUState.GPU:
274
- vram_state = VRAMState.DISABLED
275
-
276
- if cpu_state == CPUState.MPS:
277
- vram_state = VRAMState.SHARED
278
-
279
- logging.info(f"Set vram state to: {vram_state.name}")
280
-
281
- DISABLE_SMART_MEMORY = args.disable_smart_memory
282
-
283
- if DISABLE_SMART_MEMORY:
284
- logging.info("Disabling smart memory management")
285
-
286
- def get_torch_device_name(device):
287
- if hasattr(device, 'type'):
288
- if device.type == "cuda":
289
- try:
290
- allocator_backend = torch.cuda.get_allocator_backend()
291
- except:
292
- allocator_backend = ""
293
- return "{} {} : {}".format(device, torch.cuda.get_device_name(device), allocator_backend)
294
- else:
295
- return "{}".format(device.type)
296
- elif is_intel_xpu():
297
- return "{} {}".format(device, torch.xpu.get_device_name(device))
298
- elif is_ascend_npu():
299
- return "{} {}".format(device, torch.npu.get_device_name(device))
300
- else:
301
- return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
302
-
303
- try:
304
- logging.info("Device: {}".format(get_torch_device_name(get_torch_device())))
305
- except:
306
- logging.warning("Could not pick default device.")
307
-
308
-
309
- current_loaded_models = []
310
-
311
- def module_size(module):
312
- module_mem = 0
313
- sd = module.state_dict()
314
- for k in sd:
315
- t = sd[k]
316
- module_mem += t.nelement() * t.element_size()
317
- return module_mem
318
-
319
- class LoadedModel:
320
- def __init__(self, model):
321
- self._set_model(model)
322
- self.device = model.load_device
323
- self.real_model = None
324
- self.currently_used = True
325
- self.model_finalizer = None
326
- self._patcher_finalizer = None
327
-
328
- def _set_model(self, model):
329
- self._model = weakref.ref(model)
330
- if model.parent is not None:
331
- self._parent_model = weakref.ref(model.parent)
332
- self._patcher_finalizer = weakref.finalize(model, self._switch_parent)
333
-
334
- def _switch_parent(self):
335
- model = self._parent_model()
336
- if model is not None:
337
- self._set_model(model)
338
-
339
- @property
340
- def model(self):
341
- return self._model()
342
-
343
- def model_memory(self):
344
- return self.model.model_size()
345
-
346
- def model_loaded_memory(self):
347
- return self.model.loaded_size()
348
-
349
- def model_offloaded_memory(self):
350
- return self.model.model_size() - self.model.loaded_size()
351
-
352
- def model_memory_required(self, device):
353
- if device == self.model.current_loaded_device():
354
- return self.model_offloaded_memory()
355
- else:
356
- return self.model_memory()
357
-
358
- def model_load(self, lowvram_model_memory=0, force_patch_weights=False):
359
- self.model.model_patches_to(self.device)
360
- self.model.model_patches_to(self.model.model_dtype())
361
-
362
- # if self.model.loaded_size() > 0:
363
- use_more_vram = lowvram_model_memory
364
- if use_more_vram == 0:
365
- use_more_vram = 1e32
366
- self.model_use_more_vram(use_more_vram, force_patch_weights=force_patch_weights)
367
- real_model = self.model.model
368
-
369
- if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and real_model is not None:
370
- with torch.no_grad():
371
- real_model = ipex.optimize(real_model.eval(), inplace=True, graph_mode=True, concat_linear=True)
372
-
373
- self.real_model = weakref.ref(real_model)
374
- self.model_finalizer = weakref.finalize(real_model, cleanup_models)
375
- return real_model
376
-
377
- def should_reload_model(self, force_patch_weights=False):
378
- if force_patch_weights and self.model.lowvram_patch_counter() > 0:
379
- return True
380
- return False
381
-
382
- def model_unload(self, memory_to_free=None, unpatch_weights=True):
383
- if memory_to_free is not None:
384
- if memory_to_free < self.model.loaded_size():
385
- freed = self.model.partially_unload(self.model.offload_device, memory_to_free)
386
- if freed >= memory_to_free:
387
- return False
388
- self.model.detach(unpatch_weights)
389
- self.model_finalizer.detach()
390
- self.model_finalizer = None
391
- self.real_model = None
392
- return True
393
-
394
- def model_use_more_vram(self, extra_memory, force_patch_weights=False):
395
- return self.model.partially_load(self.device, extra_memory, force_patch_weights=force_patch_weights)
396
-
397
- def __eq__(self, other):
398
- return self.model is other.model
399
-
400
- def __del__(self):
401
- if self._patcher_finalizer is not None:
402
- self._patcher_finalizer.detach()
403
-
404
- def is_dead(self):
405
- return self.real_model() is not None and self.model is None
406
-
407
-
408
- def use_more_memory(extra_memory, loaded_models, device):
409
- for m in loaded_models:
410
- if m.device == device:
411
- extra_memory -= m.model_use_more_vram(extra_memory)
412
- if extra_memory <= 0:
413
- break
414
-
415
- def offloaded_memory(loaded_models, device):
416
- offloaded_mem = 0
417
- for m in loaded_models:
418
- if m.device == device:
419
- offloaded_mem += m.model_offloaded_memory()
420
- return offloaded_mem
421
-
422
- WINDOWS = any(platform.win32_ver())
423
-
424
- EXTRA_RESERVED_VRAM = 400 * 1024 * 1024
425
- if WINDOWS:
426
- EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 #Windows is higher because of the shared vram issue
427
-
428
- if args.reserve_vram is not None:
429
- EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024
430
- logging.debug("Reserving {}MB vram for other applications.".format(EXTRA_RESERVED_VRAM / (1024 * 1024)))
431
-
432
- def extra_reserved_memory():
433
- return EXTRA_RESERVED_VRAM
434
-
435
- def minimum_inference_memory():
436
- return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()
437
-
438
- def free_memory(memory_required, device, keep_loaded=[]):
439
- cleanup_models_gc()
440
- unloaded_model = []
441
- can_unload = []
442
- unloaded_models = []
443
-
444
- for i in range(len(current_loaded_models) -1, -1, -1):
445
- shift_model = current_loaded_models[i]
446
- if shift_model.device == device:
447
- if shift_model not in keep_loaded and not shift_model.is_dead():
448
- can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
449
- shift_model.currently_used = False
450
-
451
- for x in sorted(can_unload):
452
- i = x[-1]
453
- memory_to_free = None
454
- if not DISABLE_SMART_MEMORY:
455
- free_mem = get_free_memory(device)
456
- if free_mem > memory_required:
457
- break
458
- memory_to_free = memory_required - free_mem
459
- logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
460
- if current_loaded_models[i].model_unload(memory_to_free):
461
- unloaded_model.append(i)
462
-
463
- for i in sorted(unloaded_model, reverse=True):
464
- unloaded_models.append(current_loaded_models.pop(i))
465
-
466
- if len(unloaded_model) > 0:
467
- soft_empty_cache()
468
- else:
469
- if vram_state != VRAMState.HIGH_VRAM:
470
- mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True)
471
- if mem_free_torch > mem_free_total * 0.25:
472
- soft_empty_cache()
473
- return unloaded_models
474
-
475
- def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
476
- cleanup_models_gc()
477
- global vram_state
478
-
479
- inference_memory = minimum_inference_memory()
480
- extra_mem = max(inference_memory, memory_required + extra_reserved_memory())
481
- if minimum_memory_required is None:
482
- minimum_memory_required = extra_mem
483
- else:
484
- minimum_memory_required = max(inference_memory, minimum_memory_required + extra_reserved_memory())
485
-
486
- models = set(models)
487
-
488
- models_to_load = []
489
-
490
- for x in models:
491
- loaded_model = LoadedModel(x)
492
- try:
493
- loaded_model_index = current_loaded_models.index(loaded_model)
494
- except:
495
- loaded_model_index = None
496
-
497
- if loaded_model_index is not None:
498
- loaded = current_loaded_models[loaded_model_index]
499
- loaded.currently_used = True
500
- models_to_load.append(loaded)
501
- else:
502
- if hasattr(x, "model"):
503
- logging.info(f"Requested to load {x.model.__class__.__name__}")
504
- models_to_load.append(loaded_model)
505
-
506
- for loaded_model in models_to_load:
507
- to_unload = []
508
- for i in range(len(current_loaded_models)):
509
- if loaded_model.model.is_clone(current_loaded_models[i].model):
510
- to_unload = [i] + to_unload
511
- for i in to_unload:
512
- current_loaded_models.pop(i).model.detach(unpatch_all=False)
513
-
514
- total_memory_required = {}
515
- for loaded_model in models_to_load:
516
- total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
517
-
518
- for device in total_memory_required:
519
- if device != torch.device("cpu"):
520
- free_memory(total_memory_required[device] * 1.1 + extra_mem, device)
521
-
522
- for device in total_memory_required:
523
- if device != torch.device("cpu"):
524
- free_mem = get_free_memory(device)
525
- if free_mem < minimum_memory_required:
526
- models_l = free_memory(minimum_memory_required, device)
527
- logging.info("{} models unloaded.".format(len(models_l)))
528
-
529
- for loaded_model in models_to_load:
530
- model = loaded_model.model
531
- torch_dev = model.load_device
532
- if is_device_cpu(torch_dev):
533
- vram_set_state = VRAMState.DISABLED
534
- else:
535
- vram_set_state = vram_state
536
- lowvram_model_memory = 0
537
- if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM) and not force_full_load:
538
- model_size = loaded_model.model_memory_required(torch_dev)
539
- loaded_memory = loaded_model.model_loaded_memory()
540
- current_free_mem = get_free_memory(torch_dev) + loaded_memory
541
-
542
- lowvram_model_memory = max(64 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
543
- lowvram_model_memory = max(0.1, lowvram_model_memory - loaded_memory)
544
- if model_size <= lowvram_model_memory: #only switch to lowvram if really necessary
545
- lowvram_model_memory = 0
546
-
547
- if vram_set_state == VRAMState.NO_VRAM:
548
- lowvram_model_memory = 0.1
549
-
550
- loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
551
- current_loaded_models.insert(0, loaded_model)
552
- return
553
-
554
- def load_model_gpu(model):
555
- return load_models_gpu([model])
556
-
557
- def loaded_models(only_currently_used=False):
558
- output = []
559
- for m in current_loaded_models:
560
- if only_currently_used:
561
- if not m.currently_used:
562
- continue
563
-
564
- output.append(m.model)
565
- return output
566
-
567
-
568
- def cleanup_models_gc():
569
- do_gc = False
570
- for i in range(len(current_loaded_models)):
571
- cur = current_loaded_models[i]
572
- if cur.is_dead():
573
- logging.info("Potential memory leak detected with model {}, doing a full garbage collect, for maximum performance avoid circular references in the model code.".format(cur.real_model().__class__.__name__))
574
- do_gc = True
575
- break
576
-
577
- if do_gc:
578
- gc.collect()
579
- soft_empty_cache()
580
-
581
- for i in range(len(current_loaded_models)):
582
- cur = current_loaded_models[i]
583
- if cur.is_dead():
584
- logging.warning("WARNING, memory leak with model {}. Please make sure it is not being referenced from somewhere.".format(cur.real_model().__class__.__name__))
585
-
586
-
587
-
588
- def cleanup_models():
589
- to_delete = []
590
- for i in range(len(current_loaded_models)):
591
- if current_loaded_models[i].real_model() is None:
592
- to_delete = [i] + to_delete
593
-
594
- for i in to_delete:
595
- x = current_loaded_models.pop(i)
596
- del x
597
-
598
- def dtype_size(dtype):
599
- dtype_size = 4
600
- if dtype == torch.float16 or dtype == torch.bfloat16:
601
- dtype_size = 2
602
- elif dtype == torch.float32:
603
- dtype_size = 4
604
- else:
605
- try:
606
- dtype_size = dtype.itemsize
607
- except: #Old pytorch doesn't have .itemsize
608
- pass
609
- return dtype_size
610
-
611
- def unet_offload_device():
612
- if vram_state == VRAMState.HIGH_VRAM:
613
- return get_torch_device()
614
- else:
615
- return torch.device("cpu")
616
-
617
- def unet_inital_load_device(parameters, dtype):
618
- torch_dev = get_torch_device()
619
- if vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.SHARED:
620
- return torch_dev
621
-
622
- cpu_dev = torch.device("cpu")
623
- if DISABLE_SMART_MEMORY:
624
- return cpu_dev
625
-
626
- model_size = dtype_size(dtype) * parameters
627
-
628
- mem_dev = get_free_memory(torch_dev)
629
- mem_cpu = get_free_memory(cpu_dev)
630
- if mem_dev > mem_cpu and model_size < mem_dev:
631
- return torch_dev
632
- else:
633
- return cpu_dev
634
-
635
- def maximum_vram_for_weights(device=None):
636
- return (get_total_memory(device) * 0.88 - minimum_inference_memory())
637
-
638
- def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
639
- if model_params < 0:
640
- model_params = 1000000000000000000000
641
- if args.fp32_unet:
642
- return torch.float32
643
- if args.fp64_unet:
644
- return torch.float64
645
- if args.bf16_unet:
646
- return torch.bfloat16
647
- if args.fp16_unet:
648
- return torch.float16
649
- if args.fp8_e4m3fn_unet:
650
- return torch.float8_e4m3fn
651
- if args.fp8_e5m2_unet:
652
- return torch.float8_e5m2
653
-
654
- fp8_dtype = None
655
- try:
656
- for dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
657
- if dtype in supported_dtypes:
658
- fp8_dtype = dtype
659
- break
660
- except:
661
- pass
662
-
663
- if fp8_dtype is not None:
664
- if supports_fp8_compute(device): #if fp8 compute is supported the casting is most likely not expensive
665
- return fp8_dtype
666
-
667
- free_model_memory = maximum_vram_for_weights(device)
668
- if model_params * 2 > free_model_memory:
669
- return fp8_dtype
670
-
671
- for dt in supported_dtypes:
672
- if dt == torch.float16 and should_use_fp16(device=device, model_params=model_params):
673
- if torch.float16 in supported_dtypes:
674
- return torch.float16
675
- if dt == torch.bfloat16 and should_use_bf16(device, model_params=model_params):
676
- if torch.bfloat16 in supported_dtypes:
677
- return torch.bfloat16
678
-
679
- for dt in supported_dtypes:
680
- if dt == torch.float16 and should_use_fp16(device=device, model_params=model_params, manual_cast=True):
681
- if torch.float16 in supported_dtypes:
682
- return torch.float16
683
- if dt == torch.bfloat16 and should_use_bf16(device, model_params=model_params, manual_cast=True):
684
- if torch.bfloat16 in supported_dtypes:
685
- return torch.bfloat16
686
-
687
- return torch.float32
688
-
689
- # None means no manual cast
690
- def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
691
- if weight_dtype == torch.float32 or weight_dtype == torch.float64:
692
- return None
693
-
694
- fp16_supported = should_use_fp16(inference_device, prioritize_performance=False)
695
- if fp16_supported and weight_dtype == torch.float16:
696
- return None
697
-
698
- bf16_supported = should_use_bf16(inference_device)
699
- if bf16_supported and weight_dtype == torch.bfloat16:
700
- return None
701
-
702
- fp16_supported = should_use_fp16(inference_device, prioritize_performance=True)
703
- for dt in supported_dtypes:
704
- if dt == torch.float16 and fp16_supported:
705
- return torch.float16
706
- if dt == torch.bfloat16 and bf16_supported:
707
- return torch.bfloat16
708
-
709
- return torch.float32
710
-
711
- def text_encoder_offload_device():
712
- if args.gpu_only:
713
- return get_torch_device()
714
- else:
715
- return torch.device("cpu")
716
-
717
- def text_encoder_device():
718
- if args.gpu_only:
719
- return get_torch_device()
720
- elif vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.NORMAL_VRAM:
721
- if should_use_fp16(prioritize_performance=False):
722
- return get_torch_device()
723
- else:
724
- return torch.device("cpu")
725
- else:
726
- return torch.device("cpu")
727
-
728
- def text_encoder_initial_device(load_device, offload_device, model_size=0):
729
- if load_device == offload_device or model_size <= 1024 * 1024 * 1024:
730
- return offload_device
731
-
732
- if is_device_mps(load_device):
733
- return load_device
734
-
735
- mem_l = get_free_memory(load_device)
736
- mem_o = get_free_memory(offload_device)
737
- if mem_l > (mem_o * 0.5) and model_size * 1.2 < mem_l:
738
- return load_device
739
- else:
740
- return offload_device
741
-
742
- def text_encoder_dtype(device=None):
743
- if args.fp8_e4m3fn_text_enc:
744
- return torch.float8_e4m3fn
745
- elif args.fp8_e5m2_text_enc:
746
- return torch.float8_e5m2
747
- elif args.fp16_text_enc:
748
- return torch.float16
749
- elif args.fp32_text_enc:
750
- return torch.float32
751
-
752
- if is_device_cpu(device):
753
- return torch.float16
754
-
755
- return torch.float16
756
-
757
-
758
- def intermediate_device():
759
- if args.gpu_only:
760
- return get_torch_device()
761
- else:
762
- return torch.device("cpu")
763
-
764
- def vae_device():
765
- if args.cpu_vae:
766
- return torch.device("cpu")
767
- return get_torch_device()
768
-
769
- def vae_offload_device():
770
- if args.gpu_only:
771
- return get_torch_device()
772
- else:
773
- return torch.device("cpu")
774
-
775
- def vae_dtype(device=None, allowed_dtypes=[]):
776
- if args.fp16_vae:
777
- return torch.float16
778
- elif args.bf16_vae:
779
- return torch.bfloat16
780
- elif args.fp32_vae:
781
- return torch.float32
782
-
783
- for d in allowed_dtypes:
784
- if d == torch.float16 and should_use_fp16(device):
785
- return d
786
-
787
- # NOTE: bfloat16 seems to work on AMD for the VAE but is extremely slow in some cases compared to fp32
788
- if d == torch.bfloat16 and (not is_amd()) and should_use_bf16(device):
789
- return d
790
-
791
- return torch.float32
792
-
793
- def get_autocast_device(dev):
794
- if hasattr(dev, 'type'):
795
- return dev.type
796
- return "cuda"
797
-
798
- def supports_dtype(device, dtype): #TODO
799
- if dtype == torch.float32:
800
- return True
801
- if is_device_cpu(device):
802
- return False
803
- if dtype == torch.float16:
804
- return True
805
- if dtype == torch.bfloat16:
806
- return True
807
- return False
808
-
809
- def supports_cast(device, dtype): #TODO
810
- if dtype == torch.float32:
811
- return True
812
- if dtype == torch.float16:
813
- return True
814
- if directml_enabled: #TODO: test this
815
- return False
816
- if dtype == torch.bfloat16:
817
- return True
818
- if is_device_mps(device):
819
- return False
820
- if dtype == torch.float8_e4m3fn:
821
- return True
822
- if dtype == torch.float8_e5m2:
823
- return True
824
- return False
825
-
826
- def pick_weight_dtype(dtype, fallback_dtype, device=None):
827
- if dtype is None:
828
- dtype = fallback_dtype
829
- elif dtype_size(dtype) > dtype_size(fallback_dtype):
830
- dtype = fallback_dtype
831
-
832
- if not supports_cast(device, dtype):
833
- dtype = fallback_dtype
834
-
835
- return dtype
836
-
837
- def device_supports_non_blocking(device):
838
- if is_device_mps(device):
839
- return False #pytorch bug? mps doesn't support non blocking
840
- if is_intel_xpu():
841
- return False
842
- if args.deterministic: #TODO: figure out why deterministic breaks non blocking from gpu to cpu (previews)
843
- return False
844
- if directml_enabled:
845
- return False
846
- return True
847
-
848
- def device_should_use_non_blocking(device):
849
- if not device_supports_non_blocking(device):
850
- return False
851
- return False
852
- # return True #TODO: figure out why this causes memory issues on Nvidia and possibly others
853
-
854
- def force_channels_last():
855
- if args.force_channels_last:
856
- return True
857
-
858
- #TODO
859
- return False
860
-
861
- def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False):
862
- if device is None or weight.device == device:
863
- if not copy:
864
- if dtype is None or weight.dtype == dtype:
865
- return weight
866
- return weight.to(dtype=dtype, copy=copy)
867
-
868
- r = torch.empty_like(weight, dtype=dtype, device=device)
869
- r.copy_(weight, non_blocking=non_blocking)
870
- return r
871
-
872
- def cast_to_device(tensor, device, dtype, copy=False):
873
- non_blocking = device_supports_non_blocking(device)
874
- return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy)
875
-
876
- def sage_attention_enabled():
877
- return args.use_sage_attention
878
-
879
- def xformers_enabled():
880
- global directml_enabled
881
- global cpu_state
882
- if cpu_state != CPUState.GPU:
883
- return False
884
- if is_intel_xpu():
885
- return False
886
- if is_ascend_npu():
887
- return False
888
- if directml_enabled:
889
- return False
890
- return XFORMERS_IS_AVAILABLE
891
-
892
-
893
- def xformers_enabled_vae():
894
- enabled = xformers_enabled()
895
- if not enabled:
896
- return False
897
-
898
- return XFORMERS_ENABLED_VAE
899
-
900
- def pytorch_attention_enabled():
901
- global ENABLE_PYTORCH_ATTENTION
902
- return ENABLE_PYTORCH_ATTENTION
903
-
904
- def pytorch_attention_flash_attention():
905
- global ENABLE_PYTORCH_ATTENTION
906
- if ENABLE_PYTORCH_ATTENTION:
907
- #TODO: more reliable way of checking for flash attention?
908
- if is_nvidia(): #pytorch flash attention only works on Nvidia
909
- return True
910
- if is_intel_xpu():
911
- return True
912
- if is_ascend_npu():
913
- return True
914
- return False
915
-
916
- def mac_version():
917
- try:
918
- return tuple(int(n) for n in platform.mac_ver()[0].split("."))
919
- except:
920
- return None
921
-
922
- def force_upcast_attention_dtype():
923
- upcast = args.force_upcast_attention
924
-
925
- macos_version = mac_version()
926
- if macos_version is not None and ((14, 5) <= macos_version <= (15, 2)): # black image bug on recent versions of macOS
927
- upcast = True
928
-
929
- if upcast:
930
- return torch.float32
931
- else:
932
- return None
933
-
934
- def get_free_memory(dev=None, torch_free_too=False):
935
- global directml_enabled
936
- if dev is None:
937
- dev = get_torch_device()
938
-
939
- if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
940
- mem_free_total = psutil.virtual_memory().available
941
- mem_free_torch = mem_free_total
942
- else:
943
- if directml_enabled:
944
- mem_free_total = 1024 * 1024 * 1024 #TODO
945
- mem_free_torch = mem_free_total
946
- elif is_intel_xpu():
947
- stats = torch.xpu.memory_stats(dev)
948
- mem_active = stats['active_bytes.all.current']
949
- mem_reserved = stats['reserved_bytes.all.current']
950
- mem_free_torch = mem_reserved - mem_active
951
- mem_free_xpu = torch.xpu.get_device_properties(dev).total_memory - mem_reserved
952
- mem_free_total = mem_free_xpu + mem_free_torch
953
- elif is_ascend_npu():
954
- stats = torch.npu.memory_stats(dev)
955
- mem_active = stats['active_bytes.all.current']
956
- mem_reserved = stats['reserved_bytes.all.current']
957
- mem_free_npu, _ = torch.npu.mem_get_info(dev)
958
- mem_free_torch = mem_reserved - mem_active
959
- mem_free_total = mem_free_npu + mem_free_torch
960
- else:
961
- stats = torch.cuda.memory_stats(dev)
962
- mem_active = stats['active_bytes.all.current']
963
- mem_reserved = stats['reserved_bytes.all.current']
964
- mem_free_cuda, _ = torch.cuda.mem_get_info(dev)
965
- mem_free_torch = mem_reserved - mem_active
966
- mem_free_total = mem_free_cuda + mem_free_torch
967
-
968
- if torch_free_too:
969
- return (mem_free_total, mem_free_torch)
970
- else:
971
- return mem_free_total
972
-
973
- def cpu_mode():
974
- global cpu_state
975
- return cpu_state == CPUState.CPU
976
-
977
- def mps_mode():
978
- global cpu_state
979
- return cpu_state == CPUState.MPS
980
-
981
- def is_device_type(device, type):
982
- if hasattr(device, 'type'):
983
- if (device.type == type):
984
- return True
985
- return False
986
-
987
- def is_device_cpu(device):
988
- return is_device_type(device, 'cpu')
989
-
990
- def is_device_mps(device):
991
- return is_device_type(device, 'mps')
992
-
993
- def is_device_cuda(device):
994
- return is_device_type(device, 'cuda')
995
-
996
- def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
997
- global directml_enabled
998
-
999
- if device is not None:
1000
- if is_device_cpu(device):
1001
- return False
1002
-
1003
- if FORCE_FP16:
1004
- return True
1005
-
1006
- if FORCE_FP32:
1007
- return False
1008
-
1009
- if directml_enabled:
1010
- return False
1011
-
1012
- if (device is not None and is_device_mps(device)) or mps_mode():
1013
- return True
1014
-
1015
- if cpu_mode():
1016
- return False
1017
-
1018
- if is_intel_xpu():
1019
- return True
1020
-
1021
- if is_ascend_npu():
1022
- return True
1023
-
1024
- if torch.version.hip:
1025
- return True
1026
-
1027
- props = torch.cuda.get_device_properties(device)
1028
- if props.major >= 8:
1029
- return True
1030
-
1031
- if props.major < 6:
1032
- return False
1033
-
1034
- #FP16 is confirmed working on a 1080 (GP104) and on latest pytorch actually seems faster than fp32
1035
- nvidia_10_series = ["1080", "1070", "titan x", "p3000", "p3200", "p4000", "p4200", "p5000", "p5200", "p6000", "1060", "1050", "p40", "p100", "p6", "p4"]
1036
- for x in nvidia_10_series:
1037
- if x in props.name.lower():
1038
- if WINDOWS or manual_cast:
1039
- return True
1040
- else:
1041
- return False #weird linux behavior where fp32 is faster
1042
-
1043
- if manual_cast:
1044
- free_model_memory = maximum_vram_for_weights(device)
1045
- if (not prioritize_performance) or model_params * 4 > free_model_memory:
1046
- return True
1047
-
1048
- if props.major < 7:
1049
- return False
1050
-
1051
- #FP16 is just broken on these cards
1052
- nvidia_16_series = ["1660", "1650", "1630", "T500", "T550", "T600", "MX550", "MX450", "CMP 30HX", "T2000", "T1000", "T1200"]
1053
- for x in nvidia_16_series:
1054
- if x in props.name:
1055
- return False
1056
-
1057
- return True
1058
-
1059
- def should_use_bf16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
1060
- if device is not None:
1061
- if is_device_cpu(device): #TODO ? bf16 works on CPU but is extremely slow
1062
- return False
1063
-
1064
- if FORCE_FP32:
1065
- return False
1066
-
1067
- if directml_enabled:
1068
- return False
1069
-
1070
- if (device is not None and is_device_mps(device)) or mps_mode():
1071
- if mac_version() < (14,):
1072
- return False
1073
- return True
1074
-
1075
- if cpu_mode():
1076
- return False
1077
-
1078
- if is_intel_xpu():
1079
- return True
1080
-
1081
- props = torch.cuda.get_device_properties(device)
1082
- if props.major >= 8:
1083
- return True
1084
-
1085
- bf16_works = torch.cuda.is_bf16_supported()
1086
-
1087
- if bf16_works or manual_cast:
1088
- free_model_memory = maximum_vram_for_weights(device)
1089
- if (not prioritize_performance) or model_params * 4 > free_model_memory:
1090
- return True
1091
-
1092
- return False
1093
-
1094
- def supports_fp8_compute(device=None):
1095
- if not is_nvidia():
1096
- return False
1097
-
1098
- props = torch.cuda.get_device_properties(device)
1099
- if props.major >= 9:
1100
- return True
1101
- if props.major < 8:
1102
- return False
1103
- if props.minor < 9:
1104
- return False
1105
-
1106
- if int(torch_version[0]) < 2 or (int(torch_version[0]) == 2 and int(torch_version[2]) < 3):
1107
- return False
1108
-
1109
- if WINDOWS:
1110
- if (int(torch_version[0]) == 2 and int(torch_version[2]) < 4):
1111
- return False
1112
-
1113
- return True
1114
-
1115
- def soft_empty_cache(force=False):
1116
- global cpu_state
1117
- if cpu_state == CPUState.MPS:
1118
- torch.mps.empty_cache()
1119
- elif is_intel_xpu():
1120
- torch.xpu.empty_cache()
1121
- elif is_ascend_npu():
1122
- torch.npu.empty_cache()
1123
- elif torch.cuda.is_available():
1124
- torch.cuda.empty_cache()
1125
- torch.cuda.ipc_collect()
1126
-
1127
- def unload_all_models():
1128
- free_memory(1e30, get_torch_device())
1129
-
1130
-
1131
- #TODO: might be cleaner to put this somewhere else
1132
- import threading
1133
-
1134
- class InterruptProcessingException(Exception):
1135
- pass
1136
-
1137
- interrupt_processing_mutex = threading.RLock()
1138
-
1139
- interrupt_processing = False
1140
- def interrupt_current_processing(value=True):
1141
- global interrupt_processing
1142
- global interrupt_processing_mutex
1143
- with interrupt_processing_mutex:
1144
- interrupt_processing = value
1145
-
1146
- def processing_interrupted():
1147
- global interrupt_processing
1148
- global interrupt_processing_mutex
1149
- with interrupt_processing_mutex:
1150
- return interrupt_processing
1151
-
1152
- def throw_exception_if_processing_interrupted():
1153
- global interrupt_processing
1154
- global interrupt_processing_mutex
1155
- with interrupt_processing_mutex:
1156
- if interrupt_processing:
1157
- interrupt_processing = False
1158
- raise InterruptProcessingException()