import torch
from DeepCache import DeepCacheSDHelper
from diffusers.models import AutoencoderKL

from .config import Config
from .logger import Logger
from .upscaler import RealESRGAN
from .utils import timer


class Loader:
    def __init__(self):
        self.model = ""
        self.vae = None
        self.refiner = None
        self.pipeline = None
        self.upscaler = None
        self.log = Logger("Loader")
        self.device = torch.device("cuda")  # always called in CUDA context

    def should_unload_deepcache(self, cache_interval=1):
        # Unload when caching is turned off or the interval changed
        has_deepcache = hasattr(self.pipeline, "deepcache")
        if has_deepcache and cache_interval == 1:
            return True
        if has_deepcache and self.pipeline.deepcache.params["cache_interval"] != cache_interval:
            return True
        return False

    def should_unload_upscaler(self, scale=1):
        return self.upscaler is not None and self.upscaler.scale != scale

    def should_unload_refiner(self, use_refiner=False):
        return self.refiner is not None and not use_refiner

    def should_unload_pipeline(self, model=""):
        return self.pipeline is not None and self.model != model

    def should_load_deepcache(self, cache_interval=1):
        has_deepcache = hasattr(self.pipeline, "deepcache")
        if not has_deepcache and cache_interval > 1:
            return True
        return False

    def should_load_upscaler(self, scale=1):
        return self.upscaler is None and scale > 1

    def should_load_refiner(self, use_refiner=False):
        return self.refiner is None and use_refiner

    def should_load_pipeline(self, pipeline_id=""):
        if self.pipeline is None:
            return True
        if not isinstance(self.pipeline, Config.PIPELINES[pipeline_id]):
            return True
        return False

    def should_load_scheduler(self, cls, use_karras=False):
        has_karras = hasattr(self.pipeline.scheduler.config, "use_karras_sigmas")
        if not isinstance(self.pipeline.scheduler, cls):
            return True
        if has_karras and self.pipeline.scheduler.config.use_karras_sigmas != use_karras:
            return True
        return False

    def unload_all(self, model, deepcache_interval, scale, use_refiner):
        if self.should_unload_deepcache(deepcache_interval):
            self.log.info("Disabling DeepCache")
            self.pipeline.deepcache.disable()
            delattr(self.pipeline, "deepcache")
            # The refiner won't have DeepCache if it was loaded after caching was enabled
            if self.refiner and hasattr(self.refiner, "deepcache"):
                self.refiner.deepcache.disable()
                delattr(self.refiner, "deepcache")

        if self.should_unload_upscaler(scale):
            self.log.info("Unloading upscaler")
            self.upscaler = None

        if self.should_unload_refiner(use_refiner):
            self.log.info("Unloading refiner")
            self.refiner = None

        if self.should_unload_pipeline(model):
            self.log.info(f"Unloading {self.model}")
            if self.refiner:
                # Drop references to components shared with the old pipeline
                self.refiner.vae = None
                self.refiner.scheduler = None
                self.refiner.tokenizer_2 = None
                self.refiner.text_encoder_2 = None
            self.pipeline = None
            self.model = ""

    def load_deepcache(self, interval=1):
        self.log.info("Enabling DeepCache")
        self.pipeline.deepcache = DeepCacheSDHelper(pipe=self.pipeline)
        self.pipeline.deepcache.set_params(cache_interval=interval)
        self.pipeline.deepcache.enable()
        if self.refiner:
            self.refiner.deepcache = DeepCacheSDHelper(pipe=self.refiner)
            self.refiner.deepcache.set_params(cache_interval=interval)
            self.refiner.deepcache.enable()

    def load_upscaler(self, scale=1):
        with timer(f"Loading {scale}x upscaler", logger=self.log.info):
            self.upscaler = RealESRGAN(scale, device=self.device)
            self.upscaler.load_weights()

    def load_refiner(self):
        model = Config.REFINER_MODEL
        with timer(f"Loading {model}", logger=self.log.info):
            # Reuse components from the base pipeline to save memory
            refiner_kwargs = {
                "variant": "fp16",
                "torch_dtype": self.pipeline.dtype,
                "add_watermarker": False,
                "requires_aesthetics_score": True,
                "force_zeros_for_empty_prompt": False,
                "vae": self.pipeline.vae,
                "scheduler": self.pipeline.scheduler,
                "tokenizer_2": self.pipeline.tokenizer_2,
                "text_encoder_2": self.pipeline.text_encoder_2,
            }
            Pipeline = Config.PIPELINES["img2img"]
            self.refiner = Pipeline.from_pretrained(model, **refiner_kwargs).to(self.device)
        self.refiner.set_progress_bar_config(disable=True)

    def load_pipeline(self, pipeline_id, model, **kwargs):
        Pipeline = Config.PIPELINES[pipeline_id]

        # Load VAE first
        if self.vae is None:
            self.vae = AutoencoderKL.from_pretrained(
                Config.VAE_MODEL,
                torch_dtype=torch.float32,  # vae is full-precision
            ).to(self.device)
        kwargs["vae"] = self.vae

        # Load from scratch
        if self.pipeline is None:
            with timer(f"Loading {model} ({pipeline_id})", logger=self.log.info):
                if model in Config.SINGLE_FILE_MODELS:
                    checkpoint = Config.HF_REPOS[model][0]
                    self.pipeline = Pipeline.from_single_file(
                        f"https://huggingface.co/{model}/{checkpoint}",
                        **kwargs,
                    ).to(self.device)
                else:
                    self.pipeline = Pipeline.from_pretrained(model, **kwargs).to(self.device)
        # Change to a different one
        else:
            with timer(f"Changing pipeline to {pipeline_id}", logger=self.log.info):
                self.pipeline = Pipeline.from_pipe(self.pipeline).to(self.device)

        # Update model and disable terminal progress bars
        self.model = model
        self.pipeline.set_progress_bar_config(disable=True)

    def load_scheduler(self, cls, use_karras=False, **kwargs):
        self.log.info(f"Loading {cls.__name__}{' with Karras' if use_karras else ''}")
        self.pipeline.scheduler = cls(**kwargs)
        if self.refiner is not None:
            self.refiner.scheduler = self.pipeline.scheduler

    def load(self, pipeline_id, model, scheduler, deepcache_interval, scale, use_karras, use_refiner):
        Scheduler = Config.SCHEDULERS[scheduler]

        scheduler_kwargs = {
            "beta_start": 0.00085,
            "beta_end": 0.012,
            "beta_schedule": "scaled_linear",
            "timestep_spacing": "leading",
            "steps_offset": 1,
        }

        # "Euler a" doesn't accept use_karras_sigmas
        if scheduler not in ["Euler a"]:
            scheduler_kwargs["use_karras_sigmas"] = use_karras

        pipeline_kwargs = {
            "torch_dtype": torch.float16,
            "add_watermarker": False,
            "scheduler": Scheduler(**scheduler_kwargs),
        }

        # Single-file models don't need a variant
        if model not in Config.SINGLE_FILE_MODELS:
            pipeline_kwargs["variant"] = "fp16"
        else:
            pipeline_kwargs["variant"] = None

        # Unload
        self.unload_all(model, deepcache_interval, scale, use_refiner)

        # Load
        if self.should_load_pipeline(pipeline_id):
            self.load_pipeline(pipeline_id, model, **pipeline_kwargs)

        if self.should_load_refiner(use_refiner):
            self.load_refiner()

        if self.should_load_scheduler(Scheduler, use_karras):
            self.load_scheduler(Scheduler, use_karras, **scheduler_kwargs)

        if self.should_load_deepcache(deepcache_interval):
            self.load_deepcache(deepcache_interval)

        if self.should_load_upscaler(scale):
            self.load_upscaler(scale)


# Get a singleton or a new instance of the Loader
def get_loader(singleton=False):
    if not singleton:
        return Loader()
    else:
        if not hasattr(get_loader, "_instance"):
            get_loader._instance = Loader()
        assert isinstance(get_loader._instance, Loader)
        return get_loader._instance
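

# Usage sketch (illustrative only): the pipeline_id, model, and scheduler keys
# below are hypothetical and must match whatever is defined in Config.PIPELINES,
# Config.SCHEDULERS, and Config.HF_REPOS for your setup.
#
#   loader = get_loader(singleton=True)
#   loader.load(
#       pipeline_id="txt2img",
#       model="stabilityai/stable-diffusion-xl-base-1.0",
#       scheduler="Euler a",
#       deepcache_interval=2,
#       scale=2,
#       use_karras=False,
#       use_refiner=False,
#   )
#   image = loader.pipeline("an astronaut riding a horse").images[0]
#
# Subsequent load() calls with different arguments only reload the components
# that changed, which is the point of the should_load_*/should_unload_* checks.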