#!/usr/bin/env python3 """ Utility functions for the application Author: Shilpaj Bhalerao Date: Feb 26, 2025 """ import torch import gc import os import sys from PIL import Image, ImageDraw, ImageFont # Disable HF transfer to avoid download issues os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0" # Create a monkey patch for the cached_download function # This is needed because newer versions of huggingface_hub # removed cached_download but diffusers still tries to import it def apply_huggingface_patch(): import importlib import huggingface_hub # Check if cached_download is already available if hasattr(huggingface_hub, 'cached_download'): return # No need to patch # Create a wrapper around hf_hub_download to mimic the old cached_download def cached_download(*args, **kwargs): # Forward to the new function with appropriate args return huggingface_hub.hf_hub_download(*args, **kwargs) # Add the function to the huggingface_hub module setattr(huggingface_hub, 'cached_download', cached_download) # Make sure diffusers.utils.dynamic_modules_utils sees the patched module if 'diffusers.utils.dynamic_modules_utils' in sys.modules: del sys.modules['diffusers.utils.dynamic_modules_utils'] def load_models(device="cuda"): """ Load the necessary models for stable diffusion :param device: (str) Device to load models on ('cuda', 'mps', or 'cpu') :return: (tuple) (vae, tokenizer, text_encoder, unet, scheduler, pipe) """ # Apply the patch before importing diffusers apply_huggingface_patch() # Now we can safely import from diffusers from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel, StableDiffusionPipeline from transformers import CLIPTokenizer, CLIPTextModel # Set device if device == "cuda" and not torch.cuda.is_available(): device = "mps" if torch.backends.mps.is_available() else "cpu" if device == "mps": os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = "1" print(f"Loading models on {device}...") # Load the autoencoder model which will be used to decode the latents into image space vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae", use_safetensors=False) # Load the tokenizer and text encoder to tokenize and encode the text tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14") # The UNet model for generating the latents unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet", use_safetensors=False) # The noise scheduler scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000) # Load the full pipeline for concept loading pipe = StableDiffusionPipeline.from_pretrained( "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, use_safetensors=False ) # Move models to device vae = vae.to(device) text_encoder = text_encoder.to(device) unet = unet.to(device) pipe = pipe.to(device) return vae, tokenizer, text_encoder, unet, scheduler, pipe def clear_gpu_memory(): """ Clear GPU memory cache """ torch.cuda.empty_cache() gc.collect() def set_timesteps(scheduler, num_inference_steps): """ Set timesteps for the scheduler with MPS compatibility fix :param scheduler: (Scheduler) Scheduler to set timesteps for :param num_inference_steps: (int) Number of inference steps """ scheduler.set_timesteps(num_inference_steps) scheduler.timesteps = scheduler.timesteps.to(torch.float32) def pil_to_latent(input_im, vae, device): """ Convert the image to latents :param input_im: (PIL.Image) Input PIL image :param vae: (VAE) VAE model :param device: (str) Device to run on :return: (torch.Tensor) Latents from VAE's encoder """ from torchvision import transforms as tfms # Single image -> single latent in a batch (so size 1, 4, 64, 64) with torch.no_grad(): latent = vae.encode(tfms.ToTensor()(input_im).unsqueeze(0).to(device)*2-1) # Note scaling return 0.18215 * latent.latent_dist.sample() def latents_to_pil(latents, vae): """ Convert the latents to images :param latents: (torch.Tensor) Latent tensor :param vae: (VAE) VAE model :return: (list) PIL images """ # batch of latents -> list of images latents = (1 / 0.18215) * latents with torch.no_grad(): image = vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) image = image.detach().cpu().permute(0, 2, 3, 1).numpy() images = (image * 255).round().astype("uint8") pil_images = [Image.fromarray(image) for image in images] return pil_images def image_grid(imgs, rows, cols, labels=None): """ Create a grid of images with optional labels. :param imgs: (list) List of PIL images to be arranged in a grid :param rows: (int) Number of rows in the grid :param cols: (int) Number of columns in the grid :param labels: (list, optional) List of label strings for each image :return: (PIL.Image) A single image with all input images arranged in a grid and labeled """ assert len(imgs) == rows*cols, f"Number of images ({len(imgs)}) must equal rows*cols ({rows*cols})" w, h = imgs[0].size grid = Image.new('RGB', size=(cols*w, rows*h + 30 if labels else rows*h)) # Add padding at the bottom for labels if they exist label_height = 30 if labels else 0 # Paste images for i, img in enumerate(imgs): grid.paste(img, box=(i%cols*w, i//cols*h)) # Add labels if provided if labels: assert len(labels) == len(imgs), "Number of labels must match number of images" draw = ImageDraw.Draw(grid) # Try to use a standard font, fall back to default if not available try: font = ImageFont.truetype("arial.ttf", 14) except IOError: font = ImageFont.load_default() for i, label in enumerate(labels): # Position text under the image x = (i % cols) * w + 10 y = (i // cols + 1) * h - 5 # Draw black text with white outline for visibility # White outline (draw text in each direction) for offset in [(1,1), (-1,-1), (1,-1), (-1,1)]: draw.text((x+offset[0], y+offset[1]), label, fill=(255,255,255), font=font) # Main text (black) draw.text((x, y), label, fill=(0,0,0), font=font) return grid def vignette_loss(images, vignette_strength=3.0, color_shift=[1.0, 0.5, 0.0]): """ Creates a strong vignette effect (dark corners) and color shift. :param images: (torch.Tensor) Batch of images from VAE decoder (range 0-1) :param vignette_strength: (float) How strong the darkening effect is (higher = more dramatic) :param color_shift: (list) RGB color to shift the center toward [r, g, b] :return: (torch.Tensor) Loss value """ batch_size, channels, height, width = images.shape # Create coordinate grid centered at 0 with range [-1, 1] y = torch.linspace(-1, 1, height).view(-1, 1).repeat(1, width).to(images.device) x = torch.linspace(-1, 1, width).view(1, -1).repeat(height, 1).to(images.device) # Calculate radius from center (normalized [0,1]) radius = torch.sqrt(x.pow(2) + y.pow(2)) / 1.414 # Vignette mask: dark at edges, bright in center vignette = torch.exp(-vignette_strength * radius) # Color shift target: shift center toward specified color color_tensor = torch.tensor(color_shift, dtype=torch.float32).view(1, 3, 1, 1).to(images.device) center_mask = 1.0 - radius.unsqueeze(0).unsqueeze(0) center_mask = torch.pow(center_mask, 2.0) # Make the transition more dramatic # Target image with vignette and color shift target = images.clone() # Apply vignette (multiply all channels by vignette mask) for c in range(channels): target[:, c] = target[:, c] * vignette # Apply color shift in center for c in range(channels): # Shift toward target color more in center, less at edges color_offset = (color_tensor[:, c] - images[:, c]) * center_mask target[:, c] = target[:, c] + color_offset.squeeze(1) # Calculate loss - how different current image is from our target return torch.pow(images - target, 2).mean() def get_concept_embedding(concept_text, tokenizer, text_encoder, device): """ Generate CLIP embedding for a concept described in text :param concept_text: (str) Text description of the concept (e.g., "sketch painting") :param tokenizer: (CLIPTokenizer) CLIP tokenizer :param text_encoder: (CLIPTextModel) CLIP text encoder :param device: (str) Device to run on :return: (torch.Tensor) CLIP embedding for the concept """ # Tokenize the concept text concept_tokens = tokenizer( concept_text, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt" ).input_ids.to(device) # Generate the embedding using the text encoder with torch.no_grad(): concept_embedding = text_encoder(concept_tokens)[0] return concept_embedding