#!/usr/bin/env python3
"""
Utility functions for the application
Author: Shilpaj Bhalerao
Date: Feb 26, 2025
"""
import torch
import gc
import os
import sys
from PIL import Image, ImageDraw, ImageFont
# Disable HF transfer to avoid download issues
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"
# Create a monkey patch for the cached_download function
# This is needed because newer versions of huggingface_hub
# removed cached_download but diffusers still tries to import it
def apply_huggingface_patch():
    import huggingface_hub
# Check if cached_download is already available
if hasattr(huggingface_hub, 'cached_download'):
return # No need to patch
# Create a wrapper around hf_hub_download to mimic the old cached_download
def cached_download(*args, **kwargs):
# Forward to the new function with appropriate args
return huggingface_hub.hf_hub_download(*args, **kwargs)
# Add the function to the huggingface_hub module
setattr(huggingface_hub, 'cached_download', cached_download)
# Make sure diffusers.utils.dynamic_modules_utils sees the patched module
if 'diffusers.utils.dynamic_modules_utils' in sys.modules:
del sys.modules['diffusers.utils.dynamic_modules_utils']
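# Sketch (not part of the original module): the patch must run before the first
# `import diffusers` in the process; once the shim is in place, diffusers' legacy
# import of cached_download resolves against the wrapper defined above.
def _example_patched_import():
    apply_huggingface_patch()
    import diffusers  # safe to import after the shim is applied
    return diffusers.__version__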
def load_models(device="cuda"):
"""
Load the necessary models for stable diffusion
:param device: (str) Device to load models on ('cuda', 'mps', or 'cpu')
:return: (tuple) (vae, tokenizer, text_encoder, unet, scheduler, pipe)
"""
# Apply the patch before importing diffusers
apply_huggingface_patch()
# Now we can safely import from diffusers
from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel, StableDiffusionPipeline
from transformers import CLIPTokenizer, CLIPTextModel
# Set device
if device == "cuda" and not torch.cuda.is_available():
device = "mps" if torch.backends.mps.is_available() else "cpu"
if device == "mps":
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = "1"
print(f"Loading models on {device}...")
# Load the autoencoder model which will be used to decode the latents into image space
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae", use_safetensors=False)
# Load the tokenizer and text encoder to tokenize and encode the text
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
# The UNet model for generating the latents
unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet", use_safetensors=False)
# The noise scheduler
scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
# Load the full pipeline for concept loading
pipe = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5",
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
use_safetensors=False
)
# Move models to device
vae = vae.to(device)
text_encoder = text_encoder.to(device)
unet = unet.to(device)
pipe = pipe.to(device)
return vae, tokenizer, text_encoder, unet, scheduler, pipe
def clear_gpu_memory():
"""
Clear GPU memory cache
"""
torch.cuda.empty_cache()
gc.collect()
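# Sketch (illustrative only): a minimal end-to-end call using load_models and
# clear_gpu_memory. The prompt, step count, and output filename are placeholders,
# not values taken from this repository.
def _example_generate_with_pipe():
    vae, tokenizer, text_encoder, unet, scheduler, pipe = load_models()
    try:
        image = pipe("a watercolor painting of a lighthouse", num_inference_steps=25).images[0]
        image.save("lighthouse.png")
    finally:
        # Release cached GPU memory between runs
        clear_gpu_memory()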
def set_timesteps(scheduler, num_inference_steps):
"""
Set timesteps for the scheduler with MPS compatibility fix
:param scheduler: (Scheduler) Scheduler to set timesteps for
:param num_inference_steps: (int) Number of inference steps
"""
scheduler.set_timesteps(num_inference_steps)
scheduler.timesteps = scheduler.timesteps.to(torch.float32)
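# Sketch (illustrative): preparing the scheduler for a manual denoising loop.
# LMSDiscreteScheduler also exposes per-step sigmas, used to scale the initial latent noise.
def _example_prepare_scheduler(scheduler, num_inference_steps=30):
    set_timesteps(scheduler, num_inference_steps)
    return scheduler.timesteps, scheduler.sigmas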
def pil_to_latent(input_im, vae, device):
"""
Convert the image to latents
:param input_im: (PIL.Image) Input PIL image
:param vae: (VAE) VAE model
:param device: (str) Device to run on
:return: (torch.Tensor) Latents from VAE's encoder
"""
from torchvision import transforms as tfms
# Single image -> single latent in a batch (so size 1, 4, 64, 64)
with torch.no_grad():
latent = vae.encode(tfms.ToTensor()(input_im).unsqueeze(0).to(device)*2-1) # Note scaling
return 0.18215 * latent.latent_dist.sample()
def latents_to_pil(latents, vae):
"""
Convert the latents to images
:param latents: (torch.Tensor) Latent tensor
:param vae: (VAE) VAE model
:return: (list) PIL images
"""
# batch of latents -> list of images
latents = (1 / 0.18215) * latents
with torch.no_grad():
image = vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
images = (image * 255).round().astype("uint8")
pil_images = [Image.fromarray(image) for image in images]
return pil_images
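# Sketch (illustrative): a VAE round trip with the two helpers above. The image path is a
# placeholder; Stable Diffusion v1 expects 512x512 inputs, giving 1x4x64x64 latents.
def _example_vae_round_trip(vae, device="cuda"):
    input_im = Image.open("input.png").convert("RGB").resize((512, 512))
    latents = pil_to_latent(input_im, vae, device)  # shape (1, 4, 64, 64)
    return latents_to_pil(latents, vae)[0]          # decoded back to a PIL image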
def image_grid(imgs, rows, cols, labels=None):
"""
Create a grid of images with optional labels.
:param imgs: (list) List of PIL images to be arranged in a grid
:param rows: (int) Number of rows in the grid
:param cols: (int) Number of columns in the grid
:param labels: (list, optional) List of label strings for each image
:return: (PIL.Image) A single image with all input images arranged in a grid and labeled
"""
assert len(imgs) == rows*cols, f"Number of images ({len(imgs)}) must equal rows*cols ({rows*cols})"
    w, h = imgs[0].size
    # Reserve padding at the bottom of the grid for labels if they exist
    label_height = 30 if labels else 0
    grid = Image.new('RGB', size=(cols * w, rows * h + label_height))
# Paste images
for i, img in enumerate(imgs):
grid.paste(img, box=(i%cols*w, i//cols*h))
# Add labels if provided
if labels:
assert len(labels) == len(imgs), "Number of labels must match number of images"
draw = ImageDraw.Draw(grid)
# Try to use a standard font, fall back to default if not available
try:
font = ImageFont.truetype("arial.ttf", 14)
except IOError:
font = ImageFont.load_default()
for i, label in enumerate(labels):
            # Position text near the bottom edge of each image
x = (i % cols) * w + 10
y = (i // cols + 1) * h - 5
# Draw black text with white outline for visibility
# White outline (draw text in each direction)
for offset in [(1,1), (-1,-1), (1,-1), (-1,1)]:
draw.text((x+offset[0], y+offset[1]), label, fill=(255,255,255), font=font)
# Main text (black)
draw.text((x, y), label, fill=(0,0,0), font=font)
return grid
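# Sketch (illustrative): arranging four images in a labeled 2x2 grid; the labels are placeholders.
def _example_make_grid(images):
    labels = ["seed 0", "seed 1", "seed 2", "seed 3"]
    return image_grid(images[:4], rows=2, cols=2, labels=labels)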
def vignette_loss(images, vignette_strength=3.0, color_shift=[1.0, 0.5, 0.0]):
"""
Creates a strong vignette effect (dark corners) and color shift.
:param images: (torch.Tensor) Batch of images from VAE decoder (range 0-1)
:param vignette_strength: (float) How strong the darkening effect is (higher = more dramatic)
:param color_shift: (list) RGB color to shift the center toward [r, g, b]
:return: (torch.Tensor) Loss value
"""
batch_size, channels, height, width = images.shape
# Create coordinate grid centered at 0 with range [-1, 1]
y = torch.linspace(-1, 1, height).view(-1, 1).repeat(1, width).to(images.device)
x = torch.linspace(-1, 1, width).view(1, -1).repeat(height, 1).to(images.device)
# Calculate radius from center (normalized [0,1])
radius = torch.sqrt(x.pow(2) + y.pow(2)) / 1.414
# Vignette mask: dark at edges, bright in center
vignette = torch.exp(-vignette_strength * radius)
# Color shift target: shift center toward specified color
color_tensor = torch.tensor(color_shift, dtype=torch.float32).view(1, 3, 1, 1).to(images.device)
center_mask = 1.0 - radius.unsqueeze(0).unsqueeze(0)
center_mask = torch.pow(center_mask, 2.0) # Make the transition more dramatic
# Target image with vignette and color shift
target = images.clone()
# Apply vignette (multiply all channels by vignette mask)
for c in range(channels):
target[:, c] = target[:, c] * vignette
# Apply color shift in center
for c in range(channels):
# Shift toward target color more in center, less at edges
color_offset = (color_tensor[:, c] - images[:, c]) * center_mask
target[:, c] = target[:, c] + color_offset.squeeze(1)
# Calculate loss - how different current image is from our target
return torch.pow(images - target, 2).mean()
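# Sketch (illustrative, assumes an LMS-style sampling loop): using vignette_loss to guide
# generation by nudging the latents along the negative gradient of the loss. The decode and
# 0.18215 scaling mirror latents_to_pil above; loss_scale and the sigma**2 step size are
# common choices, not values taken from this repository.
def _example_vignette_guidance_step(latents, vae, sigma, loss_scale=40.0):
    latents = latents.detach().requires_grad_()
    # Decode to image space (range 0-1) so vignette_loss can be evaluated
    decoded = vae.decode((1 / 0.18215) * latents).sample
    images = (decoded / 2 + 0.5).clamp(0, 1)
    loss = vignette_loss(images)
    # Gradient of the loss w.r.t. the latents steers sampling toward the vignette effect
    grad = torch.autograd.grad(loss, latents)[0]
    return latents.detach() - loss_scale * grad * sigma ** 2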
def get_concept_embedding(concept_text, tokenizer, text_encoder, device):
"""
Generate CLIP embedding for a concept described in text
:param concept_text: (str) Text description of the concept (e.g., "sketch painting")
:param tokenizer: (CLIPTokenizer) CLIP tokenizer
:param text_encoder: (CLIPTextModel) CLIP text encoder
:param device: (str) Device to run on
:return: (torch.Tensor) CLIP embedding for the concept
"""
# Tokenize the concept text
concept_tokens = tokenizer(
concept_text,
padding="max_length",
max_length=tokenizer.model_max_length,
truncation=True,
return_tensors="pt"
).input_ids.to(device)
# Generate the embedding using the text encoder
with torch.no_grad():
concept_embedding = text_encoder(concept_tokens)[0]
return concept_embedding
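# Sketch (illustrative): blending a concept embedding into a prompt embedding before
# feeding it to the UNet. The 0.3 mixing weight is a placeholder; both embeddings have
# shape (1, 77, 768) for the CLIP ViT-L/14 text encoder loaded above.
def _example_blend_concept(prompt, concept, tokenizer, text_encoder, device, weight=0.3):
    prompt_emb = get_concept_embedding(prompt, tokenizer, text_encoder, device)
    concept_emb = get_concept_embedding(concept, tokenizer, text_encoder, device)
    return (1 - weight) * prompt_emb + weight * concept_emb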