# diffusion/lib/pipelines.py
# Last commit 1e250ff by adamelliotfields: "Progress bar for loading pipeline"
import os
from importlib import import_module
from diffusers import StableDiffusionImg2ImgPipeline, StableDiffusionPipeline
from diffusers.loaders.single_file import (
SINGLE_FILE_OPTIONAL_COMPONENTS,
load_single_file_sub_model,
)
from diffusers.loaders.single_file_utils import fetch_diffusers_config, load_single_file_checkpoint
from diffusers.models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
from diffusers.pipelines.pipeline_loading_utils import (
ALL_IMPORTABLE_CLASSES,
_get_pipeline_class,
load_sub_model,
)
from diffusers.utils import logging
from huggingface_hub import snapshot_download
from huggingface_hub.utils import validate_hf_hub_args
def _detect_model_variants(cached_folder, config_dict, variant):
    """Map component subfolders to ``variant`` when they contain variant weight files.

    A subfolder of ``cached_folder`` qualifies when it is also a key of ``config_dict``
    and holds at least one file whose second dot-separated name segment starts with
    ``variant``. Example: ``diffusion_pytorch_model.safetensors`` ->
    ``diffusion_pytorch_model.fp16.safetensors`` with variant ``"fp16"``.

    Args:
        cached_folder: Local directory returned by ``download``.
        config_dict: Pipeline config; only its keys are consulted.
        variant: Variant name, or ``None``.

    Returns:
        ``{folder_name: variant}`` for every matching subfolder. The dict is empty
        when ``variant`` is ``None``.
    """
    model_variants = {}
    if variant is None:
        return model_variants
    for folder in os.listdir(cached_folder):
        folder_path = os.path.join(cached_folder, folder)
        if not (os.path.isdir(folder_path) and folder in config_dict):
            continue
        for file_name in os.listdir(folder_path):
            parts = file_name.split(".")
            # A file without a "." (e.g. "LICENSE") has no variant segment.
            # Indexing parts[1] unconditionally would raise IndexError.
            if len(parts) > 1 and parts[1].startswith(variant):
                model_variants[folder] = variant
                break
    return model_variants


def _select_tqdm(progress):
    """Return ``progress.tqdm`` if a progress object is given, else diffusers' ``logging.tqdm``."""
    return logging.tqdm if progress is None else progress.tqdm


def _fill_missing_modules(expected_modules, init_kwargs, passed_class_obj):
    """Fill each expected module absent from ``init_kwargs`` with the caller-passed object.

    The caller-passed object is taken from ``passed_class_obj``, or ``None`` if there
    is none. ``init_kwargs`` is mutated in place. Missing modules are not treated as
    an error here; whether ``None`` is acceptable is left to the pipeline constructor.
    """
    for module in set(expected_modules) - set(init_kwargs):
        init_kwargs[module] = passed_class_obj.get(module, None)


class CustomDiffusionMixin:
    r"""
    Overrides DiffusionPipeline methods.

    Both loaders mirror the diffusers v0.30.3 code they were copied from, with several
    loading options fixed: no device_map, no offloading, no Flax/ONNX/custom pipelines.
    Each loader also accepts an optional ``progress`` object. Its
    ``tqdm(iterable, desc=...)`` method wraps the per-component loading loop. When
    ``progress`` is ``None``, diffusers' ``logging.tqdm`` is used instead.
    """

    # Copied from https://github.com/huggingface/diffusers/blob/v0.30.3/src/diffusers/pipelines/pipeline_utils.py#L480
    @classmethod
    @validate_hf_hub_args
    def from_pretrained(cls, pretrained_model_name_or_path, progress=None, **kwargs):
        """Download (if needed) and instantiate a pipeline stored in diffusers format.

        Args:
            pretrained_model_name_or_path: Hub repo id or local path, passed to ``cls.download``.
            progress: Optional object with a ``tqdm`` method used to report per-component
                loading. Defaults to diffusers' ``logging.tqdm``.
            **kwargs: Handled in this order:
                - ``torch_dtype``, ``variant`` and ``token`` are popped first.
                - The remaining kwargs are forwarded to ``cls.download``.
                - Keys naming pipeline components (e.g. ``scheduler``) are then used
                  instead of loading that component; a value of ``None`` skips it.
                - Keys naming optional ``__init__`` arguments go to the constructor.
                - Whatever is left is given to ``extract_init_dict``.

        Returns:
            The instantiated pipeline, with ``_name_or_path`` registered in its config.
        """
        torch_dtype = kwargs.pop("torch_dtype", None)
        variant = kwargs.pop("variant", None)
        token = kwargs.pop("token", None)

        # download the checkpoints and configs
        cached_folder = cls.download(
            pretrained_model_name_or_path,
            variant=variant,
            token=token,
            **kwargs,
        )

        # pop out "_ignore_files" as it is only needed for download
        config_dict = cls.load_config(cached_folder)
        config_dict.pop("_ignore_files", None)

        # which component subfolders actually ship variant (e.g. fp16) weights
        model_variants = _detect_model_variants(cached_folder, config_dict, variant)

        # load the pipeline class
        pipeline_class = _get_pipeline_class(cls, config=config_dict)

        # define expected modules given pipeline signature and define non-None initialized modules (=`init_kwargs`)
        expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
        passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
        passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}

        def load_module(name, value):
            # skip components whose library is None in the config, or that the caller passed as None
            if value[0] is None:
                return False
            if name in passed_class_obj and passed_class_obj[name] is None:
                return False
            return True

        init_dict, _, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
        init_kwargs = {
            k: init_dict.pop(k)
            for k in optional_kwargs
            if k in init_dict and k not in pipeline_class._optional_components
        }
        init_kwargs = {**init_kwargs, **passed_pipe_kwargs}
        init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}

        # load each module in the pipeline
        pipelines = import_module("diffusers.pipelines")
        tqdm = _select_tqdm(progress)
        for name, (library_name, class_name) in tqdm(
            sorted(init_dict.items()),
            desc="Loading pipeline components",
        ):
            if name in passed_class_obj:
                # passed as an argument like "scheduler"
                loaded_sub_model = passed_class_obj[name]
            else:
                loaded_sub_model = load_sub_model(
                    library_name=library_name,
                    class_name=class_name,
                    importable_classes=ALL_IMPORTABLE_CLASSES,
                    pipelines=pipelines,
                    is_pipeline_module=hasattr(pipelines, library_name),
                    pipeline_class=pipeline_class,
                    torch_dtype=torch_dtype,
                    provider=None,
                    sess_options=None,
                    device_map=None,
                    max_memory=None,
                    offload_folder=None,
                    offload_state_dict=False,
                    model_variants=model_variants,
                    name=name,
                    from_flax=False,
                    variant=variant,
                    low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT,
                    cached_folder=cached_folder,
                )
            init_kwargs[name] = loaded_sub_model

        # potentially add passed objects if expected
        _fill_missing_modules(expected_modules, init_kwargs, passed_class_obj)

        # instantiate the pipeline
        model = pipeline_class(**init_kwargs)

        # save where the model was instantiated from
        model.register_to_config(_name_or_path=pretrained_model_name_or_path)
        return model

    # Copied from https://github.com/huggingface/diffusers/blob/v0.30.3/src/diffusers/loaders/single_file.py#L270
    @classmethod
    @validate_hf_hub_args
    def from_single_file(cls, pretrained_model_link_or_path, progress=None, **kwargs):
        """Instantiate a pipeline from a single checkpoint file.

        The checkpoint is loaded with ``load_single_file_checkpoint``.
        ``fetch_diffusers_config`` picks a reference repo, and only that repo's config
        and tokenizer files (``*.json``, ``*.txt``, ``*.model``) are fetched with
        ``snapshot_download``.

        Args:
            pretrained_model_link_or_path: URL or local path of the checkpoint file.
            progress: Optional object with a ``tqdm`` method used to report per-component
                loading. Defaults to diffusers' ``logging.tqdm``.
            **kwargs: Handled as follows:
                - ``token`` and ``torch_dtype`` are popped first.
                - Keys naming pipeline components replace loading that component;
                  a value of ``None`` skips it.
                - Keys naming optional ``__init__`` arguments go to the constructor.
                - The rest are passed to ``extract_init_dict`` and to every
                  ``load_single_file_sub_model`` call.

        Returns:
            The instantiated pipeline, with ``_name_or_path`` registered in its config.
        """
        token = kwargs.pop("token", None)
        torch_dtype = kwargs.pop("torch_dtype", None)

        # load the pipeline class
        pipeline_class = _get_pipeline_class(cls, config=None)

        checkpoint = load_single_file_checkpoint(pretrained_model_link_or_path, token=token)
        config = fetch_diffusers_config(checkpoint)
        default_pretrained_model_config_name = config["pretrained_model_name_or_path"]

        # attempt to download the config files for the pipeline
        cached_model_config_path = snapshot_download(
            default_pretrained_model_config_name,
            token=token,
            allow_patterns=["**/*.json", "*.json", "*.txt", "**/*.txt", "**/*.model"],
        )

        # pop out "_ignore_files" as it is only needed for download
        config_dict = pipeline_class.load_config(cached_model_config_path)
        config_dict.pop("_ignore_files", None)

        # define expected modules given pipeline signature and define non-None initialized modules (=`init_kwargs`)
        expected_modules, optional_kwargs = pipeline_class._get_signature_keys(cls)
        passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
        passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}

        def load_module(name, value):
            # skip components whose library is None in the config, components the caller
            # passed as None, and components listed in SINGLE_FILE_OPTIONAL_COMPONENTS
            if value[0] is None:
                return False
            if name in passed_class_obj and passed_class_obj[name] is None:
                return False
            if name in SINGLE_FILE_OPTIONAL_COMPONENTS:
                return False
            return True

        init_dict, _, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
        init_kwargs = {k: init_dict.pop(k) for k in optional_kwargs if k in init_dict}
        init_kwargs = {**init_kwargs, **passed_pipe_kwargs}
        init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}

        # load each module in the pipeline
        pipelines = import_module("diffusers.pipelines")
        tqdm = _select_tqdm(progress)
        for name, (library_name, class_name) in tqdm(
            sorted(init_dict.items()),
            desc="Loading pipeline components",
        ):
            if name in passed_class_obj:
                # passed as an argument like "scheduler"
                loaded_sub_model = passed_class_obj[name]
            else:
                loaded_sub_model = load_single_file_sub_model(
                    library_name=library_name,
                    class_name=class_name,
                    name=name,
                    checkpoint=checkpoint,
                    is_pipeline_module=hasattr(pipelines, library_name),
                    cached_model_config_path=cached_model_config_path,
                    pipelines=pipelines,
                    torch_dtype=torch_dtype,
                    **kwargs,
                )
            init_kwargs[name] = loaded_sub_model

        # potentially add passed objects if expected
        _fill_missing_modules(expected_modules, init_kwargs, passed_class_obj)

        # instantiate the pipeline
        pipe = pipeline_class(**init_kwargs)

        # save where the model was instantiated from
        pipe.register_to_config(_name_or_path=pretrained_model_link_or_path)
        return pipe
class CustomStableDiffusionPipeline(CustomDiffusionMixin, StableDiffusionPipeline):
    """Text-to-image Stable Diffusion pipeline using the mixin's loaders.

    ``CustomDiffusionMixin`` comes first in the MRO, so its ``from_pretrained`` and
    ``from_single_file`` override those of ``StableDiffusionPipeline``.
    """
class CustomStableDiffusionImg2ImgPipeline(CustomDiffusionMixin, StableDiffusionImg2ImgPipeline):
    """Image-to-image Stable Diffusion pipeline using the mixin's loaders.

    ``CustomDiffusionMixin`` comes first in the MRO, so its ``from_pretrained`` and
    ``from_single_file`` override those of ``StableDiffusionImg2ImgPipeline``.
    """