"""Loaders for Stable Diffusion 1.5 pipelines, ControlNets and the IP-Adapter image encoder.

Loader results can optionally be memoised in the process-wide ``cached_models``
dict (see ``ENABLE_CPU_CACHE``) so repeated calls with identical arguments do
not reload weights.
"""
import functools
import gc
import os
from copy import deepcopy

import torch
from diffusers import (
    ControlNetModel,
    EulerAncestralDiscreteScheduler,
    StableDiffusionControlNetImg2ImgPipeline,
    StableDiffusionControlNetPipeline,
    StableDiffusionPipeline,
)
from transformers import CLIPVisionModelWithProjection

# When True, decorated loaders memoise their results in ``cached_models``.
# Read at call time, so it can be toggled after import.
ENABLE_CPU_CACHE = False
DEFAULT_BASE_MODEL = "runwayml/stable-diffusion-v1-5"

cached_models = {}  # cache for models to avoid repeated loading, key is built by _cache_key()


def _cache_key(func, args, kwargs):
    """Return the ``cached_models`` key for calling *func* with *args* / *kwargs*.

    Keyword arguments are sorted by name so the key does not depend on the
    order in which they were passed.
    """
    return func.__name__ + str(args) + str(sorted(kwargs.items()))


def _cached_call(func, args, kwargs):
    """Return the memoised result of ``func(*args, **kwargs)``, computing it on a miss."""
    key = _cache_key(func, args, kwargs)
    if key not in cached_models:
        cached_models[key] = func(*args, **kwargs)
    return cached_models[key]


def cache_model(func):
    """Memoise *func*'s result in ``cached_models`` while ``ENABLE_CPU_CACHE`` is set.

    Every cache hit returns the *same* object, so all callers share one instance.
    With caching disabled, *func* is simply called.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if not ENABLE_CPU_CACHE:
            return func(*args, **kwargs)
        return _cached_call(func, args, kwargs)
    return wrapper


def copied_cache_model(func):
    """Like ``cache_model``, but hand every caller a ``deepcopy`` of the cached value.

    Use this for results that callers modify in place, so the cached original
    stays untouched.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if not ENABLE_CPU_CACHE:
            return func(*args, **kwargs)
        return deepcopy(_cached_call(func, args, kwargs))
    return wrapper


def model_from_ckpt_or_pretrained(ckpt_or_pretrained, model_cls, original_config_file='ckpt/v1-inference.yaml', torch_dtype=torch.float16, **kwargs):
    """Load *model_cls* from a single ``.safetensors`` file or a pretrained repo / directory.

    Args:
        ckpt_or_pretrained: Path (``str`` or ``os.PathLike``) to a ``.safetensors``
            checkpoint (suffix matched case-insensitively), or any identifier
            accepted by ``model_cls.from_pretrained``.
        model_cls: Class exposing ``from_single_file`` and ``from_pretrained``.
        original_config_file: Config YAML forwarded to ``from_single_file`` only.
        torch_dtype: dtype forwarded to the loader.
        **kwargs: Forwarded unchanged to the loader.

    Returns:
        The object returned by the selected loader.
    """
    ckpt_or_pretrained = os.fspath(ckpt_or_pretrained)
    if ckpt_or_pretrained.lower().endswith(".safetensors"):
        return model_cls.from_single_file(ckpt_or_pretrained, original_config_file=original_config_file, torch_dtype=torch_dtype, **kwargs)
    return model_cls.from_pretrained(ckpt_or_pretrained, torch_dtype=torch_dtype, **kwargs)


@copied_cache_model
def load_base_model_components(base_model=DEFAULT_BASE_MODEL, torch_dtype=torch.float16):
    """Load *base_model* as a ``StableDiffusionPipeline`` on CPU and return its ``components`` dict.

    The safety checker is disabled. When caching is on, each caller receives a
    deep copy, so downstream pipelines never share these modules.
    """
    model_kwargs = dict(
        torch_dtype=torch_dtype,
        requires_safety_checker=False,
        safety_checker=None,
    )
    pipe: StableDiffusionPipeline = model_from_ckpt_or_pretrained(
        base_model,
        StableDiffusionPipeline,
        **model_kwargs
    )
    pipe.to("cpu")
    return pipe.components


@cache_model
def load_controlnet(controlnet_path, torch_dtype=torch.float16):
    """Load a ``ControlNetModel`` from *controlnet_path* (repo id or local directory)."""
    controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch_dtype)
    return controlnet


@cache_model
def load_image_encoder(torch_dtype=torch.float16):
    """Load the IP-Adapter CLIP vision encoder from ``h94/IP-Adapter`` in *torch_dtype*.

    ``torch_dtype`` defaults to ``float16`` (the previous hard-coded value). It
    should match the dtype of the pipeline the encoder is attached to.
    """
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(
        "h94/IP-Adapter",
        subfolder="models/image_encoder",
        torch_dtype=torch_dtype,
    )
    return image_encoder


def _load_controlnets(controlnet, torch_dtype):
    """Turn a ControlNet path, or a list of paths, into loaded ``ControlNetModel`` object(s)."""
    if isinstance(controlnet, list):
        return [load_controlnet(controlnet_path, torch_dtype=torch_dtype) for controlnet_path in controlnet]
    return load_controlnet(controlnet, torch_dtype=torch_dtype)


def _setup_ip_adapter(pipe, ip_adapter, plus_model, torch_dtype):
    """Attach the SD1.5 IP-Adapter (plus or base variant) to *pipe*, or unload it when disabled."""
    if not ip_adapter:
        pipe.unload_ip_adapter()
        return
    # The encoder is set before load_ip_adapter, and uses the pipeline dtype
    # so its embeddings match the UNet's dtype.
    pipe.image_encoder = load_image_encoder(torch_dtype=torch_dtype)
    weight_name = "ip-adapter-plus_sd15.safetensors" if plus_model else "ip-adapter_sd15.safetensors"
    pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name=weight_name)
    pipe.set_ip_adapter_scale(1.0)


def _configure_offload(pipe, model_cpu_offload_seq, enable_sequential_cpu_offload, vae_slicing):
    """Set *pipe*'s CPU-offload order, enable one offload strategy, and optionally VAE slicing."""
    if model_cpu_offload_seq is not None:
        pipe.model_cpu_offload_seq = model_cpu_offload_seq
    elif isinstance(pipe, StableDiffusionControlNetPipeline):
        pipe.model_cpu_offload_seq = "text_encoder->controlnet->unet->vae"
    elif isinstance(pipe, StableDiffusionControlNetImg2ImgPipeline):
        # VAE appears before the UNet too (presumably to encode the init image).
        pipe.model_cpu_offload_seq = "text_encoder->controlnet->vae->unet->vae"
    # Any other pipeline keeps its class's default offload sequence.

    # Use exactly one offload strategy.
    if enable_sequential_cpu_offload:
        pipe.enable_sequential_cpu_offload()
    else:
        pipe.enable_model_cpu_offload()

    if vae_slicing:
        pipe.enable_vae_slicing()


def load_common_sd15_pipe(base_model=DEFAULT_BASE_MODEL, device="balanced", controlnet=None, ip_adapter=False, plus_model=True, torch_dtype=torch.float16, model_cpu_offload_seq=None, enable_sequential_cpu_offload=False, vae_slicing=False, pipeline_class=None, **kwargs):
    """Build a Stable Diffusion 1.5 pipeline with optional ControlNet(s) and IP-Adapter.

    Args:
        base_model: Base model repo id / directory, or a ``.safetensors`` checkpoint.
        device: Forwarded to the pipeline loader as ``device_map``.
        controlnet: ``None``, a ControlNet path / repo id, or a list of them.
        ip_adapter: If True, load IP-Adapter weights from ``h94/IP-Adapter``.
            If False, ``unload_ip_adapter()`` is called instead.
        plus_model: Choose ``ip-adapter-plus_sd15`` (True) or ``ip-adapter_sd15`` weights.
        torch_dtype: dtype for all loaded models, including the IP-Adapter image encoder.
        model_cpu_offload_seq: Explicit offload order. If None, a ControlNet-aware
            order is used for ControlNet pipelines; other pipelines keep their default.
        enable_sequential_cpu_offload: Use sequential CPU offload instead of model offload.
        vae_slicing: Enable VAE slicing.
        pipeline_class: Pipeline class to instantiate. Defaults to
            ``StableDiffusionControlNetPipeline`` when *controlnet* is given,
            otherwise ``StableDiffusionPipeline``.
        **kwargs: Extra loader kwargs. They override the base model's components.

    Returns:
        The configured pipeline, with an ``EulerAncestralDiscreteScheduler``.
    """
    model_kwargs = dict(
        torch_dtype=torch_dtype,
        device_map=device,
        requires_safety_checker=False,
        safety_checker=None,
    )
    # Start from the (possibly cached) base components, then let caller kwargs override them.
    model_kwargs.update(load_base_model_components(base_model=base_model, torch_dtype=torch_dtype))
    model_kwargs.update(kwargs)

    if controlnet is not None:
        controlnet = _load_controlnets(controlnet, torch_dtype)
        model_kwargs.update(controlnet=controlnet)

    if pipeline_class is None:
        pipeline_class = StableDiffusionControlNetPipeline if controlnet is not None else StableDiffusionPipeline

    pipe: StableDiffusionPipeline = model_from_ckpt_or_pretrained(
        base_model,
        pipeline_class,
        **model_kwargs
    )

    _setup_ip_adapter(pipe, ip_adapter, plus_model, torch_dtype)
    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
    _configure_offload(pipe, model_cpu_offload_seq, enable_sequential_cpu_offload, vae_slicing)

    gc.collect()
    return pipe