import json
import re
import time
from contextlib import contextmanager
from datetime import datetime
from itertools import product
from os import environ
from types import MethodType
from typing import Callable, Optional

import spaces
import tomesd
import torch
from compel import Compel, DiffusersTextualInversionManager, ReturnedEmbeddingsType
from compel.prompt_parser import PromptParser
from DeepCache import DeepCacheSDHelper
from diffusers import (
    DEISMultistepScheduler,
    DPMSolverMultistepScheduler,
    EulerAncestralDiscreteScheduler,
    HeunDiscreteScheduler,
    KDPM2AncestralDiscreteScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
    StableDiffusionPipeline,
)
from diffusers.models import AutoencoderKL, AutoencoderTiny
from tgate.SD import tgate as tgate_sd
from tgate.SD_DeepCache import tgate as tgate_sd_deepcache
from torch._dynamo import OptimizedModule

# some models use the deprecated CLIPFeatureExtractor class (should use CLIPImageProcessor)
__import__("warnings").filterwarnings("ignore", category=FutureWarning, module="transformers")
__import__("transformers").logging.set_verbosity_error()

# True when SPACES_ZERO_GPU is set to "true" (any case) or "1"
ZERO_GPU = (
    environ.get("SPACES_ZERO_GPU", "").lower() == "true"
    or environ.get("SPACES_ZERO_GPU", "") == "1"
)

# textual-inversion embedding file -> token passed to `load_textual_inversion`
# NOTE(review): every token value is an empty string in this file; confirm that
# is intended and that the trigger tokens were not lost.
EMBEDDINGS = {
    "./embeddings/bad_prompt_version2.pt": "",
    "./embeddings/BadDream.pt": "",
    "./embeddings/FastNegativeV2.pt": "",
    "./embeddings/negative_hand.pt": "",
    "./embeddings/UnrealisticDream.pt": "",
}

# explicit encoding so the style file is decoded the same way on every platform
with open("./styles/twri.json", encoding="utf-8") as f:
    styles = json.load(f)


# inspired by ComfyUI
# https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/model_management.py
class Loader:
    """Process-wide singleton that owns the currently loaded pipeline.

    The first construction creates the instance and sets `cpu`, `gpu` and
    `pipe` on it; every later `Loader()` call returns that same object, so the
    pipeline stored in `pipe` is shared between calls.
    """

    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance.cpu = torch.device("cpu")
            cls._instance.gpu = torch.device("cuda")
            cls._instance.pipe = None
        return cls._instance

    def _load_deepcache(self, interval=1):
        """Attach a DeepCache helper to the pipeline, or retune the existing one.

        If a helper already exists with the requested `cache_interval` it is
        returned untouched. Otherwise an existing helper is disabled (or a new
        one is created), the interval is set and the helper is enabled.
        """
        has_deepcache = hasattr(self.pipe, "deepcache")

        if has_deepcache and self.pipe.deepcache.params["cache_interval"] == interval:
            return self.pipe.deepcache

        if has_deepcache:
            # disable before changing parameters, re-enabled below
            self.pipe.deepcache.disable()
        else:
            self.pipe.deepcache = DeepCacheSDHelper(pipe=self.pipe)

        self.pipe.deepcache.set_params(cache_interval=interval)
        self.pipe.deepcache.enable()
        return self.pipe.deepcache

    def _load_tgate(self):
        """Bind the T-GATE call onto the pipeline as `pipe.tgate` (once).

        The DeepCache-aware variant is chosen when the pipeline already has a
        `deepcache` attribute at binding time.
        """
        has_tgate = hasattr(self.pipe, "tgate")
        has_deepcache = hasattr(self.pipe, "deepcache")

        if not has_tgate:
            self.pipe.tgate = MethodType(
                tgate_sd_deepcache if has_deepcache else tgate_sd,
                self.pipe,
            )
        return self.pipe.tgate

    def _load_vae(self, model_name=None, taesd=False, dtype=None):
        """Swap the pipeline's VAE between the KL and Tiny variants when needed.

        Args:
            model_name: repository used to load the KL VAE (`subfolder="vae"`).
            taesd: True to use the Tiny VAE, False to use the KL VAE.
            dtype: torch dtype passed to `from_pretrained`.

        Returns:
            The pipeline's VAE after any swap.
        """
        vae_type = type(self.pipe.vae)
        # a compiled KL VAE shows up as an OptimizedModule
        is_kl = issubclass(vae_type, (AutoencoderKL, OptimizedModule))
        is_tiny = issubclass(vae_type, AutoencoderTiny)

        # by default all models use KL
        if is_kl and taesd:
            # can't compile tiny VAE
            print("Switching to Tiny VAE...")
            self.pipe.vae = AutoencoderTiny.from_pretrained(
                pretrained_model_name_or_path="madebyollin/taesd",
                use_safetensors=True,
                torch_dtype=dtype,
            ).to(self.gpu)
            return self.pipe.vae

        if is_tiny and not taesd:
            print("Switching to KL VAE...")
            self.pipe.vae = torch.compile(
                fullgraph=True,
                mode="reduce-overhead",
                model=AutoencoderKL.from_pretrained(
                    pretrained_model_name_or_path=model_name,
                    use_safetensors=True,
                    torch_dtype=dtype,
                    subfolder="vae",
                ).to(self.gpu),
            )
        return self.pipe.vae

    def load(self, model, scheduler, karras, taesd, deepcache_interval, dtype=None):
        """Return a pipeline configured as requested, reusing the loaded one if possible.

        When the requested model is already loaded only the scheduler, VAE,
        DeepCache and T-GATE settings are refreshed. Otherwise the current
        pipeline (if any) is dropped and the model is loaded from scratch,
        including the textual-inversion embeddings.

        Args:
            model: model repository id (compared and loaded in lower case).
            scheduler: one of the keys of the scheduler table below.
            karras: value for `use_karras_sigmas` (not passed to "PNDM" or "Euler a").
            taesd: True to use the Tiny VAE.
            deepcache_interval: DeepCache `cache_interval`.
            dtype: torch dtype for the pipeline and VAE weights.

        Raises:
            KeyError: if `scheduler` is not a known scheduler name.
        """
        model_lower = model.lower()

        schedulers = {
            "DEIS 2M": DEISMultistepScheduler,
            "DPM++ 2M": DPMSolverMultistepScheduler,
            "DPM2 a": KDPM2AncestralDiscreteScheduler,
            "Euler a": EulerAncestralDiscreteScheduler,
            "Heun": HeunDiscreteScheduler,
            "LMS": LMSDiscreteScheduler,
            "PNDM": PNDMScheduler,
        }

        scheduler_kwargs = {
            "beta_schedule": "scaled_linear",
            "timestep_spacing": "leading",
            "use_karras_sigmas": karras,
            "beta_start": 0.00085,
            "beta_end": 0.012,
            "steps_offset": 1,
        }

        # these two are constructed without the karras option
        if scheduler == "PNDM" or scheduler == "Euler a":
            del scheduler_kwargs["use_karras_sigmas"]

        # already loaded
        if self.pipe is not None:
            model_name = self.pipe.config._name_or_path
            same_model = model_name.lower() == model_lower
            same_scheduler = isinstance(self.pipe.scheduler, schedulers[scheduler])
            same_karras = (
                not hasattr(self.pipe.scheduler.config, "use_karras_sigmas")
                or self.pipe.scheduler.config.use_karras_sigmas == karras
            )

            if same_model:
                if not same_scheduler:
                    print(f"Switching to {scheduler}...")
                if not same_karras:
                    print(f"{'Enabling' if karras else 'Disabling'} Karras sigmas...")
                if not same_scheduler or not same_karras:
                    self.pipe.scheduler = schedulers[scheduler](**scheduler_kwargs)
                self._load_vae(model_lower, taesd, dtype)
                self._load_deepcache(interval=deepcache_interval)
                self._load_tgate()
                return self.pipe
            else:
                print(f"Unloading {model_name.lower()}...")
                self.pipe = None
                torch.cuda.empty_cache()

        # built only when a pipeline is actually loaded, so the reuse path above
        # does not construct a scheduler instance that is immediately discarded
        pipe_kwargs = {
            "scheduler": schedulers[scheduler](**scheduler_kwargs),
            "pretrained_model_name_or_path": model_lower,
            "requires_safety_checker": False,
            "use_safetensors": True,
            "safety_checker": None,
            "torch_dtype": dtype,
        }

        # no fp16 available
        if not ZERO_GPU and model_lower not in [
            "sg161222/realistic_vision_v5.1_novae",
            "prompthero/openjourney-v4",
            "linaqruf/anything-v3-1",
        ]:
            pipe_kwargs["variant"] = "fp16"

        print(f"Loading {model_lower} with {'Tiny' if taesd else 'KL'} VAE...")
        self.pipe = StableDiffusionPipeline.from_pretrained(**pipe_kwargs).to(self.gpu)
        self._load_vae(model_lower, taesd, dtype)
        self._load_deepcache(interval=deepcache_interval)
        self._load_tgate()
        self.pipe.load_textual_inversion(
            pretrained_model_name_or_path=list(EMBEDDINGS.keys()),
            tokens=list(EMBEDDINGS.values()),
        )
        return self.pipe


# applies tome to the pipeline
@contextmanager
def token_merging(pipe, tome_ratio=0):
    """Apply the tomesd patch to `pipe` for the duration of the `with` block.

    The patch is applied only when `tome_ratio` is greater than 0 and is always
    removed on exit, including when the body raises.
    """
    try:
        if tome_ratio > 0:
            tomesd.apply_patch(pipe, max_downsample=1, sx=2, sy=2, ratio=tome_ratio)
        yield
    finally:
        tomesd.remove_patch(pipe)  # idempotent


# parse prompts with arrays
def parse_prompt(prompt: str) -> list[str]:
    """Expand `[[a,b]]` arrays in a prompt into every combination.

    Each `[[...]]` group is split on commas; the result is one prompt per
    element of the cartesian product of all groups, with each group replaced
    by its (stripped) chosen option. A prompt without arrays is returned as a
    single-element list.

    >>> parse_prompt("a [[red, blue]] car")
    ['a red car', 'a blue car']
    """
    arrays = re.findall(r"\[\[(.*?)\]\]", prompt)

    if not arrays:
        return [prompt]

    tokens = [item.split(",") for item in arrays]

    prompts = []
    for combo in product(*tokens):
        current_prompt = prompt
        for i, token in enumerate(combo):
            # count=1 so identical arrays appearing twice are filled in order
            current_prompt = current_prompt.replace(f"[[{arrays[i]}]]", token.strip(), 1)
        prompts.append(current_prompt)
    return prompts


def apply_style(prompt, style_name, negative=False):
    """Apply the named style from the module-level `styles` list to a prompt.

    For a positive prompt the style's `prompt` template is formatted with the
    prompt; for a negative prompt the style's `negative_prompt` is appended
    after " . ". The prompt is returned unchanged when `style_name` is empty,
    "None", or not found.
    """
    if not style_name or style_name == "None":
        return prompt

    for style in styles:
        if style["name"] == style_name:
            if negative:
                return prompt + " . " + style["negative_prompt"]
            return style["prompt"].format(prompt=prompt)
    return prompt


# 1024x1024 for 50 steps can take ~10s each
@spaces.GPU(duration=44)
def generate(
    positive_prompt,
    negative_prompt="",
    style=None,
    seed=None,
    model="runwayml/stable-diffusion-v1-5",
    scheduler="PNDM",
    width=512,
    height=512,
    guidance_scale=7.5,
    inference_steps=50,
    num_images=1,
    karras=False,
    taesd=False,
    clip_skip=False,
    truncate_prompts=False,
    increment_seed=True,
    deepcache_interval=1,
    tgate_step=0,
    tome_ratio=0,
    log: Optional[Callable[[str], None]] = None,
    Error=Exception,
):
    """Generate `num_images` images and return them with their seeds.

    The positive prompt may contain `[[a,b]]` arrays (see `parse_prompt`);
    image `i` uses variant `i % len(variants)`. The selected style is applied
    to both prompts before they are embedded with Compel.

    Args:
        seed: starting seed; None or a negative value picks one from the clock.
        increment_seed: add 1 to the seed after every image.
        tgate_step: gate step; values <= 0 or above `inference_steps` are
            replaced by `inference_steps`.
        log: optional callback that receives a one-line timing summary.
        Error: exception class used for user-facing failures.

    Returns:
        A list of `(image, seed_as_str)` tuples, one per generated image.

    Raises:
        Error: if CUDA is unavailable or a prompt cannot be parsed.
    """
    if not torch.cuda.is_available():
        raise Error("CUDA not available")

    # https://pytorch.org/docs/stable/generated/torch.manual_seed.html
    if seed is None or seed < 0:
        seed = int(datetime.now().timestamp() * 1_000_000) % (2**64)

    TORCH_DTYPE = (
        torch.bfloat16
        if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
        else torch.float16
    )

    EMBEDDINGS_TYPE = (
        ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NORMALIZED
        if clip_skip
        else ReturnedEmbeddingsType.LAST_HIDDEN_STATES_NORMALIZED
    )

    with torch.inference_mode():
        start = time.perf_counter()
        loader = Loader()
        pipe = loader.load(model, scheduler, karras, taesd, deepcache_interval, TORCH_DTYPE)

        # prompt embeds
        compel = Compel(
            textual_inversion_manager=DiffusersTextualInversionManager(pipe),
            dtype_for_device_getter=lambda _: TORCH_DTYPE,
            returned_embeddings_type=EMBEDDINGS_TYPE,
            truncate_long_prompts=truncate_prompts,
            text_encoder=pipe.text_encoder,
            tokenizer=pipe.tokenizer,
            device=pipe.device,
        )

        images = []
        current_seed = seed

        try:
            styled_negative_prompt = apply_style(negative_prompt, style, negative=True)
            neg_embeds = compel(styled_negative_prompt)
        except PromptParser.ParsingException as err:
            raise Error("ParsingException: Invalid negative prompt") from err

        # the prompt variants and the gate step are the same for every image
        all_positive_prompts = parse_prompt(positive_prompt)

        # cap the tgate step
        gate_step = min(
            tgate_step if tgate_step > 0 else inference_steps,
            inference_steps,
        )

        for i in range(num_images):
            # seeded generator for each iteration
            generator = torch.Generator(device=pipe.device).manual_seed(current_seed)

            try:
                prompt_index = i % len(all_positive_prompts)
                pos_prompt = all_positive_prompts[prompt_index]
                styled_pos_prompt = apply_style(pos_prompt, style)
                pos_embeds = compel(styled_pos_prompt)
                # pad into per-image tensors; `neg_embeds` itself is left
                # untouched so one long prompt does not change the padding of
                # every later image
                image_pos_embeds, image_neg_embeds = (
                    compel.pad_conditioning_tensors_to_same_length([pos_embeds, neg_embeds])
                )
            except PromptParser.ParsingException as err:
                raise Error("ParsingException: Invalid prompt") from err

            with token_merging(pipe, tome_ratio=tome_ratio):
                result = pipe.tgate(
                    num_inference_steps=inference_steps,
                    negative_prompt_embeds=image_neg_embeds,
                    guidance_scale=guidance_scale,
                    prompt_embeds=image_pos_embeds,
                    gate_step=gate_step,
                    generator=generator,
                    height=height,
                    width=width,
                )
                images.append((result.images[0], str(current_seed)))

            if increment_seed:
                current_seed += 1

        if ZERO_GPU:
            # spaces always start fresh
            loader.pipe = None

        end = time.perf_counter()
        diff = end - start
        if log:
            log(f"Generated {len(images)} image{'s' if len(images) != 1 else ''} in {diff:.2f}s")
        return images