import json
import inspect
import os
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from transformers import CLIPTextModel, CLIPTokenizer

from diffusers.loaders import FromSingleFileMixin
from diffusers.utils import (
    USE_PEFT_BACKEND,
    deprecate,
    logging,
)
from diffusers.utils.torch_utils import randn_tensor
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.pipeline_utils import *
from diffusers.pipelines.pipeline_utils import _get_pipeline_class
from diffusers.models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT

from diffusers_patch.models.unet_2d_condition_woct import UNet2DConditionWoCTModel
from diffusers_patch.pipelines.oms.utils import SDXLTextEncoder, SDXLTokenizer

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
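
# NOTE: several helpers used below (`get_class_obj_and_candidates`, `maybe_raise_or_warn`,
# `ALL_IMPORTABLE_CLASSES`, `DUMMY_MODULES_FOLDER`, `TRANSFORMERS_DUMMY_MODULES_FOLDER`,
# `DIFFUSERS_CACHE`, `HF_HUB_OFFLINE`, `ModelCard`, `CONNECTED_PIPES_KEYS`, `version`,
# `transformers`, `PreTrainedModel`, and the `is_*_available` checks) are expected to come from
# the wildcard import of `diffusers.pipelines.pipeline_utils` above; whether all of them are
# re-exported depends on the installed diffusers version.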
def load_sub_model_oms(
    library_name: str,
    class_name: str,
    importable_classes: List[Any],
    pipelines: Any,
    is_pipeline_module: bool,
    pipeline_class: Any,
    torch_dtype: torch.dtype,
    provider: Any,
    sess_options: Any,
    device_map: Optional[Union[Dict[str, torch.device], str]],
    max_memory: Optional[Dict[Union[int, str], Union[int, str]]],
    offload_folder: Optional[Union[str, os.PathLike]],
    offload_state_dict: bool,
    model_variants: Dict[str, str],
    name: str,
    from_flax: bool,
    variant: str,
    low_cpu_mem_usage: bool,
    cached_folder: Union[str, os.PathLike],
):
    """Helper method to load the module `name` from `library_name` and `class_name`."""
    # retrieve class candidates
    class_obj, class_candidates = get_class_obj_and_candidates(
        library_name,
        class_name,
        importable_classes,
        pipelines,
        is_pipeline_module,
        component_name=name,
        cache_dir=cached_folder,
    )
    load_method_name = None
    # retrieve the load method name
    for class_name, class_candidate in class_candidates.items():
        if class_candidate is not None and issubclass(class_obj, class_candidate):
            load_method_name = importable_classes[class_name][1]

    # if the load method name is None, then we have a dummy module -> raise an error
    if load_method_name is None:
        none_module = class_obj.__module__
        is_dummy_path = none_module.startswith(DUMMY_MODULES_FOLDER) or none_module.startswith(
            TRANSFORMERS_DUMMY_MODULES_FOLDER
        )
        if is_dummy_path and "dummy" in none_module:
            # call class_obj for a nice error message about missing requirements
            class_obj()

        raise ValueError(
            f"The component {class_obj} of {pipeline_class} cannot be loaded as it does not seem to have"
            f" any of the loading methods defined in {ALL_IMPORTABLE_CLASSES}."
        )
    load_method = getattr(class_obj, load_method_name)

    # add kwargs to the loading method
    import diffusers

    loading_kwargs = {}
    if issubclass(class_obj, torch.nn.Module):
        loading_kwargs["torch_dtype"] = torch_dtype
    if issubclass(class_obj, diffusers.OnnxRuntimeModel):
        loading_kwargs["provider"] = provider
        loading_kwargs["sess_options"] = sess_options

    is_diffusers_model = issubclass(class_obj, diffusers.ModelMixin)

    if is_transformers_available():
        transformers_version = version.parse(version.parse(transformers.__version__).base_version)
    else:
        transformers_version = "N/A"

    is_transformers_model = (
        is_transformers_available()
        and issubclass(class_obj, PreTrainedModel)
        and transformers_version >= version.parse("4.20.0")
    )

    # When loading a transformers model, if the device_map is None, the weights will be initialized, as opposed to diffusers.
    # To make default loading faster we set the `low_cpu_mem_usage=low_cpu_mem_usage` flag which is `True` by default.
    # This makes sure that the weights won't be initialized, which significantly speeds up loading.
    if is_diffusers_model or is_transformers_model:
        loading_kwargs["device_map"] = device_map
        loading_kwargs["max_memory"] = max_memory
        loading_kwargs["offload_folder"] = offload_folder
        loading_kwargs["offload_state_dict"] = offload_state_dict
        loading_kwargs["variant"] = model_variants.pop(name, None)
        if from_flax:
            loading_kwargs["from_flax"] = True

        # the following can be deleted once the minimum required `transformers` version
        # is higher than 4.27
        if (
            is_transformers_model
            and loading_kwargs["variant"] is not None
            and transformers_version < version.parse("4.27.0")
        ):
            raise ImportError(
                f"When passing `variant='{variant}'`, please make sure to upgrade your `transformers` version to at least 4.27.0.dev0"
            )
        elif is_transformers_model and loading_kwargs["variant"] is None:
            loading_kwargs.pop("variant")

        # if `from_flax` and the model is a transformers model, it currently cannot be loaded with `low_cpu_mem_usage`
        if not (from_flax and is_transformers_model):
            loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
        else:
            loading_kwargs["low_cpu_mem_usage"] = False
    # check if this component lives in the OMS subdirectory
    if 'oms' in name:
        config_name = os.path.join(cached_folder, name, 'config.json')
        with open(config_name, "r", encoding="utf-8") as f:
            index = json.load(f)
        file_path_or_name = index['_name_or_path']
        if 'SDXL' in index.get('_class_name', 'CLIP'):
            loaded_sub_model = load_method(file_path_or_name, **loading_kwargs)
        elif 'subfolder' in index.keys():
            loading_kwargs["subfolder"] = index["subfolder"]
            loaded_sub_model = load_method(file_path_or_name, **loading_kwargs)
        else:
            # fall back to loading directly from `_name_or_path` (assumed intent; the original code
            # left this case unhandled, which would raise a NameError on the return below)
            loaded_sub_model = load_method(file_path_or_name, **loading_kwargs)
    else:
        # check if the module is in a subdirectory
        if os.path.isdir(os.path.join(cached_folder, name)):
            loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
        else:
            # otherwise load from the root directory
            loaded_sub_model = load_method(cached_folder, **loading_kwargs)

    return loaded_sub_model
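
# The `oms` branch above expects a small `config.json` inside the OMS component's subfolder.
# A minimal sketch of the fields it reads (the repo id below is a placeholder, not a real path):
#
#     {
#         "_class_name": "SDXLTextEncoder",              # anything containing "SDXL" selects the dual-encoder path
#         "_name_or_path": "some-org/some-text-encoder", # passed to the component's load method
#         "subfolder": "text_encoder"                    # optional; forwarded as `subfolder=` when present
#     }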

class OMSPipeline(DiffusionPipeline, FromSingleFileMixin):
    def __init__(
        self,
        oms_module: UNet2DConditionWoCTModel,
        sd_pipeline: DiffusionPipeline,
        oms_text_encoder: Optional[Union[CLIPTextModel, SDXLTextEncoder]],
        oms_tokenizer: Optional[Union[CLIPTokenizer, SDXLTokenizer]],
        sd_scheduler=None,
    ):
        # assert sd_pipeline is not None
        if oms_tokenizer is None:
            oms_tokenizer = sd_pipeline.tokenizer
        if oms_text_encoder is None:
            oms_text_encoder = sd_pipeline.text_encoder

        # For OMS with SDXL text encoders
        if 'SDXL' in oms_text_encoder.__class__.__name__:
            self.is_dual_text_encoder = True
        else:
            self.is_dual_text_encoder = False

        self.register_modules(
            oms_module=oms_module,
            oms_text_encoder=oms_text_encoder,
            oms_tokenizer=oms_tokenizer,
            sd_pipeline=sd_pipeline,
        )

        if sd_scheduler is None:
            self.scheduler = sd_pipeline.scheduler
        else:
            self.scheduler = sd_scheduler
            sd_pipeline.scheduler = sd_scheduler

        self.vae_scale_factor = 2 ** (len(sd_pipeline.vae.config.block_out_channels) - 1)
        self.default_sample_size = sd_pipeline.unet.config.sample_size
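
    # Note: the OMS step and the wrapped SD pipeline always share one scheduler object; passing
    # `sd_scheduler` swaps it for both, so they use the same alphas_cumprod schedule.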
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents
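
    # Example: for SDXL-style settings (sample_size 128, vae_scale_factor 8) and a 1024x1024 request,
    # `prepare_latents` returns a tensor of shape (batch_size * num_images_per_prompt, 4, 128, 128),
    # assuming the OMS module uses 4 latent channels like the base UNet.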
    def oms_step(self, predict_v, latents, do_classifier_free_guidance_for_oms, oms_guidance_scale, generator, alpha_prod_t_prev):
        if do_classifier_free_guidance_for_oms:
            pred_uncond, pred_text = predict_v.chunk(2)
            predict_v = pred_uncond + oms_guidance_scale * (pred_text - pred_uncond)

        # The OMS step starts from pure noise (zero terminal SNR), so alpha_prod_t is fixed to 0.
        # Keeping the full DDPM-style update below makes the simplification explicit.
        alpha_prod_t = torch.zeros_like(alpha_prod_t_prev)
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev
        current_alpha_t = alpha_prod_t / alpha_prod_t_prev
        current_beta_t = 1 - current_alpha_t

        pred_original_sample = (alpha_prod_t**0.5) * latents - (beta_prod_t**0.5) * predict_v
        # pred_original_sample = - predict_v
        pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
        current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t

        pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * latents

        # TODO: rescale to unit variance; it does not seem to be needed in practice
        device = latents.device
        variance_noise = randn_tensor(
            latents.shape, generator=generator, device=device, dtype=latents.dtype
        )
        variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t
        variance = torch.clamp(variance, min=1e-20) * variance_noise

        latents = pred_prev_sample + variance
        return latents
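
    # With alpha_prod_t fixed to 0, the update above reduces to (derived directly from the code):
    #     pred_original_sample       = -predict_v
    #     pred_original_sample_coeff = sqrt(alpha_prod_t_prev)
    #     current_sample_coeff       = 0
    #     variance                   = (1 - alpha_prod_t_prev)
    # so the returned latents are  -sqrt(alpha_prod_t_prev) * predict_v + (1 - alpha_prod_t_prev) * noise.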
    def oms_text_encode(self, prompt, num_images_per_prompt, device):
        max_length = None if self.is_dual_text_encoder else self.oms_tokenizer.model_max_length

        if self.is_dual_text_encoder:
            tokenized_prompts = self.oms_tokenizer(
                prompt,
                padding='max_length',
                max_length=max_length,
                truncation=True,
                return_tensors='pt',
            ).input_ids
            tokenized_prompts = torch.stack([tokenized_prompts[0], tokenized_prompts[1]], dim=1)
            text_embeddings, _ = self.oms_text_encoder(
                [tokenized_prompts[:, 0, :].to(device), tokenized_prompts[:, 1, :].to(device)]
            )  # type: ignore
        elif 'clip' in self.oms_text_encoder.config_class.model_type:
            tokenized_prompts = self.oms_tokenizer(
                prompt,
                padding='max_length',
                max_length=max_length,
                truncation=True,
                return_tensors='pt',
            ).input_ids
            text_embeddings = self.oms_text_encoder(tokenized_prompts.to(device))[0]  # type: ignore
        else:  # T5
            tokenized_prompts = self.oms_tokenizer(
                prompt,
                padding='max_length',
                max_length=max_length,
                truncation=True,
                add_special_tokens=True,
                return_tensors='pt',
            ).input_ids
            # Note: the T5 text encoder outputs "None" under fp16, so run it in fp32
            with torch.cuda.amp.autocast(dtype=torch.float32):
                text_embeddings = self.oms_text_encoder(tokenized_prompts.to(device))[0]

        # duplicate text embeddings for each generation per prompt
        bs_embed, seq_len, _ = text_embeddings.shape
        text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)  # type: ignore
        text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)

        return text_embeddings
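
    # In every branch the result is reshaped to (batch_size * num_images_per_prompt, seq_len, hidden_dim)
    # so it can be concatenated with the negative-prompt embeddings for classifier-free guidance in __call__.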
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
        resume_download = kwargs.pop("resume_download", False)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
        use_auth_token = kwargs.pop("use_auth_token", None)
        revision = kwargs.pop("revision", None)
        from_flax = kwargs.pop("from_flax", False)
        torch_dtype = kwargs.pop("torch_dtype", None)
        custom_pipeline = kwargs.pop("custom_pipeline", None)
        custom_revision = kwargs.pop("custom_revision", None)
        provider = kwargs.pop("provider", None)
        sess_options = kwargs.pop("sess_options", None)
        device_map = kwargs.pop("device_map", None)
        max_memory = kwargs.pop("max_memory", None)
        offload_folder = kwargs.pop("offload_folder", None)
        offload_state_dict = kwargs.pop("offload_state_dict", False)
        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
        variant = kwargs.pop("variant", None)
        use_safetensors = kwargs.pop("use_safetensors", None)
        load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
        # 1. Download the checkpoints and configs
        # use snapshot download here to get it working from from_pretrained
        if not os.path.isdir(pretrained_model_name_or_path):
            if pretrained_model_name_or_path.count("/") > 1:
                raise ValueError(
                    f'The provided pretrained_model_name_or_path "{pretrained_model_name_or_path}"'
                    " is neither a valid local path nor a valid repo id. Please check the parameter."
                )
            cached_folder = cls.download(
                pretrained_model_name_or_path,
                cache_dir=cache_dir,
                resume_download=resume_download,
                force_download=force_download,
                proxies=proxies,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                revision=revision,
                from_flax=from_flax,
                use_safetensors=use_safetensors,
                custom_pipeline=custom_pipeline,
                custom_revision=custom_revision,
                variant=variant,
                load_connected_pipeline=load_connected_pipeline,
                **kwargs,
            )
        else:
            cached_folder = pretrained_model_name_or_path
        config_dict = cls.load_config(cached_folder)

        # pop out "_ignore_files" as it is only needed for download
        config_dict.pop("_ignore_files", None)

        # 2. Define which model components should load variants
        # We retrieve the information by matching whether variant
        # model checkpoints exist in the subfolders
        model_variants = {}
        if variant is not None:
            for folder in os.listdir(cached_folder):
                folder_path = os.path.join(cached_folder, folder)
                is_folder = os.path.isdir(folder_path) and folder in config_dict
                variant_exists = is_folder and any(
                    p.split(".")[1].startswith(variant) for p in os.listdir(folder_path)
                )
                if variant_exists:
                    model_variants[folder] = variant

        # 3. Load the pipeline class, if using custom module then load it from the hub
        # if we load from explicit class, let's use it
        pipeline_class = _get_pipeline_class(
            cls,
            config_dict,
            load_connected_pipeline=load_connected_pipeline,
            custom_pipeline=custom_pipeline,
            cache_dir=cache_dir,
            revision=custom_revision,
        )

        # DEPRECATED: To be removed in 1.0.0
        if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse(
            version.parse(config_dict["_diffusers_version"]).base_version
        ) <= version.parse("0.5.1"):
            from diffusers import StableDiffusionInpaintPipeline, StableDiffusionInpaintPipelineLegacy

            pipeline_class = StableDiffusionInpaintPipelineLegacy

            deprecation_message = (
                "You are using a legacy checkpoint for inpainting with Stable Diffusion, therefore we are loading the"
                f" {StableDiffusionInpaintPipelineLegacy} class instead of {StableDiffusionInpaintPipeline}. For"
                " better inpainting results, we strongly suggest using Stable Diffusion's official inpainting"
                " checkpoint: https://huggingface.co/runwayml/stable-diffusion-inpainting instead or adapting your"
                f" checkpoint {pretrained_model_name_or_path} to the format of"
                " https://huggingface.co/runwayml/stable-diffusion-inpainting. Note that we do not actively maintain"
                " the {StableDiffusionInpaintPipelineLegacy} class and will likely remove it in version 1.0.0."
            )
            deprecate("StableDiffusionInpaintPipelineLegacy", "1.0.0", deprecation_message, standard_warn=False)
        # 4. Define expected modules given pipeline signature
        # and define non-None initialized modules (=`init_kwargs`)

        # some modules can be passed directly to the init
        # in this case they are already instantiated in `kwargs`
        # extract them here
        expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
        passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
        passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}

        init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)

        # define init kwargs and make sure that optional component modules are filtered out
        init_kwargs = {
            k: init_dict.pop(k)
            for k in optional_kwargs
            if k in init_dict and k not in pipeline_class._optional_components
        }
        init_kwargs = {**init_kwargs, **passed_pipe_kwargs}

        # remove `null` components
        def load_module(name, value):
            if value[0] is None:
                return False
            if name in passed_class_obj and passed_class_obj[name] is None:
                return False
            return True

        init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}

        # Special case: safety_checker must be loaded separately when using `from_flax`
        if from_flax and "safety_checker" in init_dict and "safety_checker" not in passed_class_obj:
            raise NotImplementedError(
                "The safety checker cannot be automatically loaded when loading weights `from_flax`."
                " Please, pass `safety_checker=None` to `from_pretrained`, and load the safety checker"
                " separately if you need it."
            )
        # 5. Throw nice warnings / errors for fast accelerate loading
        if len(unused_kwargs) > 0:
            logger.warning(
                f"Keyword arguments {unused_kwargs} are not expected by {pipeline_class.__name__} and will be ignored."
            )

        if low_cpu_mem_usage and not is_accelerate_available():
            low_cpu_mem_usage = False
            logger.warning(
                "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
                " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
                " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
                " install accelerate\n```\n."
            )

        if device_map is not None and not is_torch_version(">=", "1.9.0"):
            raise NotImplementedError(
                "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
                " `device_map=None`."
            )

        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
            raise NotImplementedError(
                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
                " `low_cpu_mem_usage=False`."
            )

        if low_cpu_mem_usage is False and device_map is not None:
            raise ValueError(
                f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and"
                " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
            )
        # import it here to avoid circular import
        from diffusers import pipelines

        # 6. Load each module in the pipeline
        for name, (library_name, class_name) in logging.tqdm(init_dict.items(), desc="Loading pipeline components..."):
            # 6.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
            class_name = class_name[4:] if class_name.startswith("Flax") else class_name

            # 6.2 Define all importable classes
            is_pipeline_module = hasattr(pipelines, library_name)
            importable_classes = ALL_IMPORTABLE_CLASSES
            loaded_sub_model = None

            # 6.3 Use passed sub model or load class_name from library_name
            if name in passed_class_obj:
                # if the model is in a pipeline module, then we load it from the pipeline
                # check that passed_class_obj has correct parent class
                maybe_raise_or_warn(
                    library_name, library, class_name, importable_classes, passed_class_obj, name, is_pipeline_module
                )

                loaded_sub_model = passed_class_obj[name]
            else:
                # load sub model
                loaded_sub_model = load_sub_model_oms(
                    library_name=library_name,
                    class_name=class_name,
                    importable_classes=importable_classes,
                    pipelines=pipelines,
                    is_pipeline_module=is_pipeline_module,
                    pipeline_class=pipeline_class,
                    torch_dtype=torch_dtype,
                    provider=provider,
                    sess_options=sess_options,
                    device_map=device_map,
                    max_memory=max_memory,
                    offload_folder=offload_folder,
                    offload_state_dict=offload_state_dict,
                    model_variants=model_variants,
                    name=name,
                    from_flax=from_flax,
                    variant=variant,
                    low_cpu_mem_usage=low_cpu_mem_usage,
                    cached_folder=cached_folder,
                )
                logger.info(
                    f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}."
                )

            init_kwargs[name] = loaded_sub_model  # UNet(...), # DiffusionSchedule(...)
        if pipeline_class._load_connected_pipes and os.path.isfile(os.path.join(cached_folder, "README.md")):
            modelcard = ModelCard.load(os.path.join(cached_folder, "README.md"))
            connected_pipes = {prefix: getattr(modelcard.data, prefix, [None])[0] for prefix in CONNECTED_PIPES_KEYS}
            load_kwargs = {
                "cache_dir": cache_dir,
                "resume_download": resume_download,
                "force_download": force_download,
                "proxies": proxies,
                "local_files_only": local_files_only,
                "use_auth_token": use_auth_token,
                "revision": revision,
                "torch_dtype": torch_dtype,
                "custom_pipeline": custom_pipeline,
                "custom_revision": custom_revision,
                "provider": provider,
                "sess_options": sess_options,
                "device_map": device_map,
                "max_memory": max_memory,
                "offload_folder": offload_folder,
                "offload_state_dict": offload_state_dict,
                "low_cpu_mem_usage": low_cpu_mem_usage,
                "variant": variant,
                "use_safetensors": use_safetensors,
            }
            connected_pipes = {
                prefix: DiffusionPipeline.from_pretrained(repo_id, **load_kwargs.copy())
                for prefix, repo_id in connected_pipes.items()
                if repo_id is not None
            }

            for prefix, connected_pipe in connected_pipes.items():
                # add connected pipes to `init_kwargs` with <prefix>_<component_name>, e.g. "prior_text_encoder"
                init_kwargs.update(
                    {"_".join([prefix, name]): component for name, component in connected_pipe.components.items()}
                )
        # 7. Potentially add passed objects if expected
        missing_modules = set(expected_modules) - set(init_kwargs.keys())
        passed_modules = list(passed_class_obj.keys())
        optional_modules = pipeline_class._optional_components

        if len(missing_modules) > 0 and missing_modules <= set(passed_modules + optional_modules):
            for module in missing_modules:
                init_kwargs[module] = passed_class_obj.get(module, None)
        elif len(missing_modules) > 0:
            passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs
            raise ValueError(
                f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
            )

        # 8. Instantiate the pipeline
        model = pipeline_class(**init_kwargs)

        # 9. Save where the model was instantiated from
        model.register_to_config(_name_or_path=pretrained_model_name_or_path)
        return model
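
    # This override mirrors `DiffusionPipeline.from_pretrained` but routes component loading through
    # `load_sub_model_oms`, so OMS-specific subfolders (see the `config.json` sketch above) can point
    # at external text encoders / tokenizers instead of weights stored inside the pipeline repo itself.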
    # @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        oms_prompt: Union[str, List[str]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        oms_guidance_scale: float = 1.0,
        oms_flag: bool = True,
        **kwargs,
    ):
"""Pseudo-doc for OMS""" | |
if oms_flag is True: | |
if oms_prompt is not None: | |
sd_prompt = prompt | |
prompt = oms_prompt | |
if prompt is not None and isinstance(prompt, str): | |
batch_size = 1 | |
elif prompt is not None and isinstance(prompt, list): | |
batch_size = len(prompt) | |
height = height or self.default_sample_size * self.vae_scale_factor | |
width = width or self.default_sample_size * self.vae_scale_factor | |
device = self._execution_device | |
## Guidance flag for OMS | |
if oms_guidance_scale is not None: | |
do_classifier_free_guidance_for_oms = True | |
else: | |
do_classifier_free_guidance_for_oms = False | |
oms_prompt_emb = self.oms_text_encode(prompt,num_images_per_prompt,device) | |
if do_classifier_free_guidance_for_oms: | |
oms_negative_prompt = ([''] * (batch_size // num_images_per_prompt)) | |
oms_negative_prompt_emb = self.oms_text_encode(oms_negative_prompt,num_images_per_prompt,device) | |
# 4. Prepare timesteps | |
self.scheduler.set_timesteps(num_inference_steps, device=device) | |
timesteps = self.scheduler.timesteps | |
# 5. Prepare latent variables | |
num_channels_latents = self.oms_module.config.in_channels | |
latents = self.prepare_latents( | |
batch_size * num_images_per_prompt, | |
num_channels_latents, | |
height, | |
width, | |
oms_prompt_emb.dtype, | |
device, | |
generator, | |
latents=None, | |
) | |
## OMS CFG | |
if do_classifier_free_guidance_for_oms: | |
oms_prompt_emb = torch.cat([oms_negative_prompt_emb, oms_prompt_emb], dim=0) | |
## OMS to device | |
oms_prompt_emb = oms_prompt_emb.to(device) | |
## Perform OMS | |
alphas_cumprod = self.scheduler.alphas_cumprod.to(device) | |
alpha_prod_t_prev = alphas_cumprod[int(timesteps[0].item())] | |
latent_input_oms = torch.cat([latents] * 2) if do_classifier_free_guidance_for_oms else latents | |
v_pred_oms = self.oms_module(latent_input_oms, oms_prompt_emb)['sample'] | |
latents = self.oms_step(v_pred_oms, latents, do_classifier_free_guidance_for_oms, oms_guidance_scale, generator, alpha_prod_t_prev) | |
if oms_prompt is not None: | |
prompt = sd_prompt | |
print('OMS Completed') | |
else: | |
print("OMS unloaded") | |
latents = None | |
        output = self.sd_pipeline(
            prompt=prompt,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            num_images_per_prompt=num_images_per_prompt,
            generator=generator,
            latents=latents,
            **kwargs,
        )

        return output
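

# A minimal usage sketch (not part of the original file). The OMS checkpoint id below is a placeholder
# and the exact weights/repo are an assumption; substitute the SDXL base model and OMS weights you use.
if __name__ == "__main__":
    from diffusers import StableDiffusionXLPipeline

    sd_pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16"
    )
    sd_pipe.to("cuda")

    # `sd_pipeline` is injected as an already-instantiated component, so only the OMS modules are loaded here.
    oms_pipe = OMSPipeline.from_pretrained(
        "path/or/repo-id/of-oms-weights",  # placeholder: an OMS checkpoint repo or local folder
        sd_pipeline=sd_pipe,
        torch_dtype=torch.float16,
    )
    oms_pipe.to("cuda")

    image = oms_pipe(
        prompt="a photo of a black cat at night",
        oms_prompt="a photo of a black cat",  # optional: a separate prompt for the OMS step only
        oms_guidance_scale=2.0,
        num_inference_steps=30,
    ).images[0]
    image.save("oms_sample.png")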