# diffusion/lib/loader.py
import gc
from threading import Lock
import torch
from DeepCache import DeepCacheSDHelper
from diffusers import ControlNetModel
from diffusers.models.attention_processor import AttnProcessor2_0, IPAdapterAttnProcessor2_0
from .config import Config
from .logger import Logger
from .upscaler import RealESRGAN
from .utils import clear_cuda_cache, safe_progress, timer
class Loader:
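    """Thread-safe singleton that lazily loads the Diffusers pipeline and its
    optional components (upscaler, DeepCache, IP-Adapter, ControlNet) and
    unloads whatever no longer matches the requested configuration."""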
_instance = None
_lock = Lock()
def __new__(cls):
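        # Serialize instantiation so concurrent callers always share one instance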
with cls._lock:
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance.pipe = None
cls._instance.model = None
cls._instance.upscaler = None
cls._instance.controlnet = None
cls._instance.ip_adapter = None
cls._instance.log = Logger("Loader")
return cls._instance
    def _should_unload_upscaler(self, scale=1):
        # RealESRGAN weights are scale-specific, so a new scale requires a reload
        return self.upscaler is not None and self.upscaler.scale != scale
    def _should_unload_deepcache(self, interval=1):
        has_deepcache = hasattr(self.pipe, "deepcache")
        # an interval of 1 disables DeepCache entirely
        if has_deepcache and interval == 1:
            return True
        # reload when the requested interval differs from the active one
        if has_deepcache and self.pipe.deepcache.params["cache_interval"] != interval:
            return True
        return False
    def _should_unload_ip_adapter(self, model="", ip_adapter=""):
        # unload if the base model changed
        if self.model and self.model.lower() != model.lower():
            return True
        # unload if the IP-Adapter was removed from the request
        if self.ip_adapter and not ip_adapter:
            return True
        return False
    def _should_unload_controlnet(self, kind="", controlnet=""):
        if self.controlnet is None:
            return False
        # the annotator changed
        if self.controlnet.lower() != controlnet.lower():
            return True
        # the new pipeline is not a ControlNet pipeline
        if not kind.startswith("controlnet_"):
            return True
        return False
    def _should_unload_pipeline(self, kind="", model="", controlnet=""):
        if self.pipe is None:
            return False
        # the base model changed
        if self.model.lower() != model.lower():
            return True
        # the requested kind maps to a different pipeline class
        if kind in Config.PIPELINES and not isinstance(self.pipe, Config.PIPELINES[kind]):
            return True
        # a ControlNet change requires rebuilding the pipeline
        if self._should_unload_controlnet(kind, controlnet):
            return True
        return False
def _unload_upscaler(self):
if self.upscaler is not None:
with timer(f"Unloading {self.upscaler.scale}x upscaler", logger=self.log.info):
self.upscaler.to("cpu")
    def _unload_deepcache(self):
        # getattr guards against the attribute not existing at all
        if getattr(self.pipe, "deepcache", None) is not None:
            self.log.info("Disabling DeepCache")
            self.pipe.deepcache.disable()
            delattr(self.pipe, "deepcache")
# Copied from https://github.com/huggingface/diffusers/blob/v0.28.0/src/diffusers/loaders/ip_adapter.py#L300
def _unload_ip_adapter(self):
if self.ip_adapter is not None:
with timer("Unloading IP-Adapter", logger=self.log.info):
if not isinstance(self.pipe, Config.PIPELINES["img2img"]):
self.pipe.image_encoder = None
self.pipe.register_to_config(image_encoder=[None, None])
self.pipe.feature_extractor = None
self.pipe.unet.encoder_hid_proj = None
self.pipe.unet.config.encoder_hid_dim_type = None
self.pipe.register_to_config(feature_extractor=[None, None])
                # restore the default attention processors (AttnProcessor2_0 requires torch >= 2)
                attn_procs = {}
                for name, value in self.pipe.unet.attn_processors.items():
                    attn_procs[name] = (
                        AttnProcessor2_0()
                        if isinstance(value, IPAdapterAttnProcessor2_0)
                        else value.__class__()
                    )
                self.pipe.unet.set_attn_processor(attn_procs)
def _unload_pipeline(self):
if self.pipe is not None:
with timer(f"Unloading {self.model}", logger=self.log.info):
self.pipe.to("cpu")
def _unload(
self,
kind="",
model="",
controlnet="",
ip_adapter="",
deepcache=1,
scale=1,
):
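        # track which attributes to reset to None once everything is unloaded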
to_unload = []
if self._should_unload_deepcache(deepcache): # remove deepcache first
self._unload_deepcache()
if self._should_unload_upscaler(scale):
self._unload_upscaler()
to_unload.append("upscaler")
if self._should_unload_ip_adapter(model, ip_adapter):
self._unload_ip_adapter()
to_unload.append("ip_adapter")
if self._should_unload_controlnet(kind, controlnet):
to_unload.append("controlnet")
if self._should_unload_pipeline(kind, model, controlnet):
self._unload_pipeline()
to_unload.append("model")
to_unload.append("pipe")
        # Drop references first so garbage collection can reclaim them,
        # then flush the CUDA cache to return the freed memory to the driver
        for component in to_unload:
            setattr(self, component, None)
        gc.collect()
        clear_cuda_cache()
    def _should_load_upscaler(self, scale=1):
        return self.upscaler is None and scale > 1
    def _should_load_deepcache(self, interval=1):
        has_deepcache = hasattr(self.pipe, "deepcache")
        # an interval of 1 means DeepCache stays disabled
        if not has_deepcache and interval != 1:
            return True
        # reload when the requested interval differs from the active one
        if has_deepcache and self.pipe.deepcache.params["cache_interval"] != interval:
            return True
        return False
    def _should_load_ip_adapter(self, ip_adapter=""):
        return not self.ip_adapter and bool(ip_adapter)
    def _should_load_pipeline(self):
        return self.pipe is None
def _load_upscaler(self, scale=1):
if self._should_load_upscaler(scale):
try:
msg = f"Loading {scale}x upscaler"
with timer(msg, logger=self.log.info):
self.upscaler = RealESRGAN(scale, device=self.pipe.device)
self.upscaler.load_weights()
except Exception as e:
self.log.error(f"Error loading {scale}x upscaler: {e}")
self.upscaler = None
def _load_deepcache(self, interval=1):
if self._should_load_deepcache(interval):
self.log.info("Enabling DeepCache")
self.pipe.deepcache = DeepCacheSDHelper(self.pipe)
self.pipe.deepcache.set_params(cache_interval=interval)
self.pipe.deepcache.enable()
def _load_ip_adapter(self, ip_adapter=""):
if self._should_load_ip_adapter(ip_adapter):
msg = "Loading IP-Adapter"
with timer(msg, logger=self.log.info):
self.pipe.load_ip_adapter(
"h94/IP-Adapter",
subfolder="models",
weight_name=f"ip-adapter-{ip_adapter}_sd15.safetensors",
)
                # a scale of 0.5 empirically works best
                self.pipe.set_ip_adapter_scale(0.5)
self.ip_adapter = ip_adapter
def _load_pipeline(
self,
kind,
model,
progress,
**kwargs,
):
        pipeline = Config.PIPELINES[kind]
        if self._should_load_pipeline():
            try:
                with timer(f"Loading {model} ({kind})", logger=self.log.info):
                    self.model = model
                    if model.lower() in Config.MODEL_CHECKPOINTS:
                        self.pipe = pipeline.from_single_file(
                            f"https://huggingface.co/{model}/{Config.MODEL_CHECKPOINTS[model.lower()]}",
                            **kwargs,
                        ).to("cuda")
                    else:
                        self.pipe = pipeline.from_pretrained(model, **kwargs).to("cuda")
            except Exception as e:
                self.log.error(f"Error loading {model}: {e}")
                self.model = None
                self.pipe = None
                return
        # switch pipeline classes without reloading weights
        if not isinstance(self.pipe, pipeline):
            self.pipe = pipeline.from_pipe(self.pipe).to("cuda")
        if self.pipe is not None:
            # disable the terminal progress bar when an external callback is provided
            self.pipe.set_progress_bar_config(disable=progress is not None)
def load(
self,
kind,
ip_adapter,
model,
scheduler,
annotator,
deepcache,
scale,
karras,
progress,
):
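        """Load the requested pipeline and bring every optional component
        (scheduler, ControlNet, IP-Adapter, DeepCache, upscaler) in sync with
        the request, unloading anything that no longer matches."""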
scheduler_kwargs = {
"beta_schedule": "scaled_linear",
"timestep_spacing": "leading",
"beta_start": 0.00085,
"beta_end": 0.012,
"steps_offset": 1,
}
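        # DDIM, Euler a, and PNDM do not accept use_karras_sigmas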
if scheduler not in ["DDIM", "Euler a", "PNDM"]:
scheduler_kwargs["use_karras_sigmas"] = karras
# https://github.com/huggingface/diffusers/blob/8a3f0c1/scripts/convert_original_stable_diffusion_to_diffusers.py#L939
if scheduler == "DDIM":
scheduler_kwargs["clip_sample"] = False
scheduler_kwargs["set_alpha_to_one"] = False
pipe_kwargs = {
"safety_checker": None,
"requires_safety_checker": False,
"scheduler": Config.SCHEDULERS[scheduler](**scheduler_kwargs),
}
        # single-file checkpoints have no fp16 variant; diffusers-format repos do
        if model.lower() not in Config.MODEL_CHECKPOINTS:
            pipe_kwargs["variant"] = "fp16"
        else:
            pipe_kwargs["variant"] = None
        # diffusers loads weights in fp32 by default, so request fp16 explicitly
        pipe_kwargs["torch_dtype"] = torch.float16
        # Config.ANNOTATORS maps an annotator ID to its repo: canny -> lllyasviel/control_sd15_canny
        if kind.startswith("controlnet_"):
            pipe_kwargs["controlnet"] = ControlNetModel.from_pretrained(
                Config.ANNOTATORS[annotator],
                torch_dtype=torch.float16,
                variant="fp16",
            )
            self.controlnet = annotator
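        # tear down stale components before (re)building the pipeline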
self._unload(kind, model, annotator, ip_adapter, deepcache, scale)
self._load_pipeline(kind, model, progress, **pipe_kwargs)
        # bail out if the model failed to load
        if self.pipe is None:
            return
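        # check whether the active scheduler already satisfies the request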
same_scheduler = isinstance(self.pipe.scheduler, Config.SCHEDULERS[scheduler])
same_karras = (
not hasattr(self.pipe.scheduler.config, "use_karras_sigmas")
or self.pipe.scheduler.config.use_karras_sigmas == karras
)
# same model, different scheduler
if self.model.lower() == model.lower():
if not same_scheduler:
self.log.info(f"Enabling {scheduler} scheduler")
if not same_karras:
self.log.info(f"{'Enabling' if karras else 'Disabling'} Karras sigmas")
if not same_scheduler or not same_karras:
self.pipe.scheduler = Config.SCHEDULERS[scheduler](**scheduler_kwargs)
        # each optional component counts as one step toward the progress bar
        current_step = 1
        total_steps = sum(
            [
                self._should_load_deepcache(deepcache),
                self._should_load_ip_adapter(ip_adapter),
                self._should_load_upscaler(scale),
            ]
        )
        desc = "Configuring pipeline"
        if self._should_load_deepcache(deepcache):
            self._load_deepcache(deepcache)
            safe_progress(progress, current_step, total_steps, desc)
            current_step += 1
        if self._should_load_ip_adapter(ip_adapter):
            self._load_ip_adapter(ip_adapter)
            safe_progress(progress, current_step, total_steps, desc)
            current_step += 1
        if self._should_load_upscaler(scale):
            self._load_upscaler(scale)
            safe_progress(progress, current_step, total_steps, desc)
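# Example usage (hypothetical model name; Config defines the real model list
# and scheduler names):
#
#   loader = Loader()
#   loader.load(
#       kind="txt2img",
#       ip_adapter="",
#       model="Lykon/dreamshaper-8",  # hypothetical
#       scheduler="DDIM",
#       annotator="",
#       deepcache=1,  # 1 disables DeepCache
#       scale=1,  # 1 disables upscaling
#       karras=True,
#       progress=None,
#   )
#   pipe = loader.pipe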