# Unique3D / scripts/sd_model_zoo.py
# Header recovered from a repository web view: author Wuvin, commit 8981664 ("add offload"), 5.1 kB.
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, EulerAncestralDiscreteScheduler, StableDiffusionControlNetImg2ImgPipeline, StableDiffusionControlNetImg2ImgPipeline, StableDiffusionPipeline
from transformers import CLIPVisionModelWithProjection
import torch
from copy import deepcopy
# Global switch read by `cache_model` / `copied_cache_model` on every call:
# when False, the decorated loaders run on each call and nothing is memoized.
ENABLE_CPU_CACHE: bool = False

# Base checkpoint used when callers do not pass `base_model`.
DEFAULT_BASE_MODEL: str = "runwayml/stable-diffusion-v1-5"

# Memo table shared by the cache decorators, so repeated loads can be avoided.
# Keys are the decorated function's name followed by its stringified arguments.
cached_models: dict = {}
def cache_model(func):
    """Memoize a model-loading function in the module-level ``cached_models``.

    While ``ENABLE_CPU_CACHE`` is true, the first call for a given set of
    arguments runs ``func`` and stores its result. Later calls with the same
    arguments return that same object, without copying, so all callers share
    one instance. While the flag is false, ``func`` is simply called every
    time. The flag is read on each call, so toggling it takes effect
    immediately.

    Args:
        func: The loader to wrap.

    Returns:
        A wrapper with ``func``'s signature, name and docstring.
    """
    import functools

    # wraps() keeps __name__/__doc__ of the loader instead of reporting "wrapper".
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if not ENABLE_CPU_CACHE:
            return func(*args, **kwargs)
        # Sort keyword arguments so f(a=1, b=2) and f(b=2, a=1) map to one key.
        # When kwargs are already in sorted order, the string matches the old format.
        model_name = func.__name__ + str(args) + str(dict(sorted(kwargs.items())))
        if model_name not in cached_models:
            cached_models[model_name] = func(*args, **kwargs)
        return cached_models[model_name]

    return wrapper
def copied_cache_model(func):
    """Memoize a loader in ``cached_models`` and hand out deep copies.

    While ``ENABLE_CPU_CACHE`` is true, the first call for a given set of
    arguments runs ``func`` and stores its result. Every call, including the
    first, then returns ``deepcopy`` of the stored object, so callers can
    mutate what they receive without altering the cached original. While the
    flag is false, ``func`` is called every time and its result is returned
    directly.

    Args:
        func: The loader to wrap.

    Returns:
        A wrapper with ``func``'s signature, name and docstring.
    """
    import functools

    # wraps() keeps __name__/__doc__ of the loader instead of reporting "wrapper".
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if not ENABLE_CPU_CACHE:
            return func(*args, **kwargs)
        # Sort keyword arguments so f(a=1, b=2) and f(b=2, a=1) map to one key.
        # When kwargs are already in sorted order, the string matches the old format.
        model_name = func.__name__ + str(args) + str(dict(sorted(kwargs.items())))
        if model_name not in cached_models:
            cached_models[model_name] = func(*args, **kwargs)
        # Never return the cached object itself; callers get an independent copy.
        return deepcopy(cached_models[model_name])

    return wrapper
def model_from_ckpt_or_pretrained(ckpt_or_pretrained, model_cls, original_config_file='ckpt/v1-inference.yaml', torch_dtype=torch.float16, **kwargs):
    """Instantiate ``model_cls`` from a single checkpoint file or via ``from_pretrained``.

    Paths ending in ``.safetensors`` or ``.ckpt`` (case-insensitive) are loaded
    with ``model_cls.from_single_file`` together with ``original_config_file``.
    Anything else is handed to ``model_cls.from_pretrained``.

    Args:
        ckpt_or_pretrained: A single-file checkpoint path, or any identifier
            ``model_cls.from_pretrained`` accepts. May be ``str`` or a path object.
        model_cls: A class providing ``from_single_file`` / ``from_pretrained``
            class methods.
        original_config_file: Passed only on the single-file path.
        torch_dtype: Forwarded to whichever loader is used.
        **kwargs: Forwarded unchanged to whichever loader is used.

    Returns:
        The object returned by the selected loader.
    """
    # str() lets pathlib.Path inputs through (Path has no ``endswith``);
    # lower() makes the extension check case-insensitive.
    is_single_file = str(ckpt_or_pretrained).lower().endswith((".safetensors", ".ckpt"))
    if is_single_file:
        # NOTE: ``.ckpt`` checkpoints are typically pickle-based; only load trusted files.
        return model_cls.from_single_file(
            ckpt_or_pretrained,
            original_config_file=original_config_file,
            torch_dtype=torch_dtype,
            **kwargs,
        )
    return model_cls.from_pretrained(ckpt_or_pretrained, torch_dtype=torch_dtype, **kwargs)
@copied_cache_model
def load_base_model_components(base_model=DEFAULT_BASE_MODEL, torch_dtype=torch.float16):
    """Load a base ``StableDiffusionPipeline`` on CPU and return its components dict.

    The safety checker is disabled (``safety_checker=None``,
    ``requires_safety_checker=False``). Because of ``@copied_cache_model``,
    cached results are handed out as deep copies when CPU caching is enabled.

    Args:
        base_model: Anything ``model_from_ckpt_or_pretrained`` accepts.
        torch_dtype: dtype for the loaded weights.

    Returns:
        ``pipe.components`` of the loaded pipeline.
    """
    base_pipe: StableDiffusionPipeline = model_from_ckpt_or_pretrained(
        base_model,
        StableDiffusionPipeline,
        torch_dtype=torch_dtype,
        requires_safety_checker=False,
        safety_checker=None,
    )
    # Keep the shared base weights on CPU; callers decide final placement.
    base_pipe.to("cpu")
    return base_pipe.components
@cache_model
def load_controlnet(controlnet_path, torch_dtype=torch.float16):
    """Load a ``ControlNetModel`` from ``controlnet_path`` via ``from_pretrained``.

    With CPU caching enabled, ``@cache_model`` returns the same instance for
    repeated calls with the same arguments (no copy).

    Args:
        controlnet_path: Identifier passed to ``ControlNetModel.from_pretrained``.
        torch_dtype: dtype for the loaded weights.
    """
    return ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch_dtype)
@cache_model
def load_image_encoder(torch_dtype=torch.float16):
    """Load the CLIP vision encoder from ``h94/IP-Adapter`` (``models/image_encoder``).

    With CPU caching enabled, ``@cache_model`` returns the same instance for
    repeated calls with the same arguments (no copy).

    Args:
        torch_dtype: dtype for the encoder weights. The default, float16, is
            the dtype this loader always used, so existing no-argument calls
            behave exactly as before and hit the same cache entry.

    Returns:
        The loaded ``CLIPVisionModelWithProjection``.
    """
    return CLIPVisionModelWithProjection.from_pretrained(
        "h94/IP-Adapter",
        subfolder="models/image_encoder",
        torch_dtype=torch_dtype,
    )
def _load_controlnet_models(controlnet, torch_dtype):
    """Resolve one ControlNet path, or a list of paths, into loaded model(s)."""
    if isinstance(controlnet, list):
        return [load_controlnet(path, torch_dtype=torch_dtype) for path in controlnet]
    return load_controlnet(controlnet, torch_dtype=torch_dtype)


def _configure_ip_adapter(pipe, ip_adapter, plus_model):
    """Attach IP-Adapter weights and image encoder to ``pipe``, or unload them."""
    if not ip_adapter:
        pipe.unload_ip_adapter()
        return
    # NOTE(review): load_image_encoder() called with no arguments loads float16
    # weights regardless of this pipeline's torch_dtype; confirm that is
    # intended for non-fp16 pipelines.
    pipe.image_encoder = load_image_encoder()
    weight_name = "ip-adapter-plus_sd15.safetensors" if plus_model else "ip-adapter_sd15.safetensors"
    pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name=weight_name)
    pipe.set_ip_adapter_scale(1.0)


def _configure_offload(pipe, model_cpu_offload_seq, enable_sequential_cpu_offload):
    """Set the CPU-offload order on ``pipe`` and enable exactly one offload strategy."""
    if model_cpu_offload_seq is not None:
        pipe.model_cpu_offload_seq = model_cpu_offload_seq
    elif isinstance(pipe, StableDiffusionControlNetPipeline):
        pipe.model_cpu_offload_seq = "text_encoder->controlnet->unet->vae"
    elif isinstance(pipe, StableDiffusionControlNetImg2ImgPipeline):
        pipe.model_cpu_offload_seq = "text_encoder->controlnet->vae->unet->vae"
    # Other pipeline classes keep their own default order.

    # Sequential and model offload are alternative strategies; apply only one.
    if enable_sequential_cpu_offload:
        pipe.enable_sequential_cpu_offload()
    else:
        pipe.enable_model_cpu_offload()


def load_common_sd15_pipe(base_model=DEFAULT_BASE_MODEL, device="balanced", controlnet=None, ip_adapter=False, plus_model=True, torch_dtype=torch.float16, model_cpu_offload_seq=None, enable_sequential_cpu_offload=False, vae_slicing=False, pipeline_class=None, **kwargs):
    """Build an SD1.5-style diffusers pipeline with optional ControlNet and IP-Adapter.

    Base components come from ``load_base_model_components`` and are passed to
    ``pipeline_class``. Afterwards the scheduler is replaced with
    ``EulerAncestralDiscreteScheduler`` and CPU offloading is enabled.

    Args:
        base_model: Anything ``model_from_ckpt_or_pretrained`` accepts.
        device: Forwarded to the pipeline loader as ``device_map``.
        controlnet: ``None``, a single ControlNet path, or a list of paths.
        ip_adapter: If true, attach the IP-Adapter image encoder and weights at
            scale 1.0; otherwise ``pipe.unload_ip_adapter()`` is called.
        plus_model: Select ``ip-adapter-plus_sd15`` (True) or
            ``ip-adapter_sd15`` (False) weights. Only used with ``ip_adapter``.
        torch_dtype: dtype for the base components and ControlNet weights.
        model_cpu_offload_seq: Explicit offload order. When ``None``, a default
            is chosen for the two ControlNet pipeline classes.
        enable_sequential_cpu_offload: Use sequential offload instead of model
            offload.
        vae_slicing: If true, call ``pipe.enable_vae_slicing()``.
        pipeline_class: Pipeline class to build. When ``None``, the class is
            ``StableDiffusionControlNetPipeline`` if ``controlnet`` is given and
            ``StableDiffusionPipeline`` otherwise.
        **kwargs: Extra loader keyword arguments. They override same-named base
            components; a ``controlnet`` entry is itself overridden when the
            ``controlnet`` argument is given.

    Returns:
        The configured pipeline.
    """
    model_kwargs = dict(
        torch_dtype=torch_dtype,
        device_map=device,
        requires_safety_checker=False,
        safety_checker=None,
    )
    model_kwargs.update(load_base_model_components(base_model=base_model, torch_dtype=torch_dtype))
    model_kwargs.update(kwargs)

    if controlnet is not None:
        controlnet = _load_controlnet_models(controlnet, torch_dtype)
        model_kwargs.update(controlnet=controlnet)

    if pipeline_class is None:
        pipeline_class = StableDiffusionPipeline if controlnet is None else StableDiffusionControlNetPipeline

    pipe: StableDiffusionPipeline = model_from_ckpt_or_pretrained(
        base_model,
        pipeline_class,
        **model_kwargs
    )

    _configure_ip_adapter(pipe, ip_adapter, plus_model)
    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
    _configure_offload(pipe, model_cpu_offload_seq, enable_sequential_cpu_offload)

    if vae_slicing:
        pipe.enable_vae_slicing()

    # Release temporaries created while loading (e.g. the base-component copies).
    import gc
    gc.collect()
    return pipe