#!/usr/bin/env python3
"""
Utility functions for the application
Author: Shilpaj Bhalerao
Date: Feb 26, 2025
"""
import gc
import os
import sys

import torch
from PIL import Image, ImageDraw, ImageFont

# Disable HF transfer to avoid download issues
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"


# Create a monkey patch for the cached_download function
# This is needed because newer versions of huggingface_hub
# removed cached_download but diffusers still tries to import it
def apply_huggingface_patch():
    import huggingface_hub

    # Check if cached_download is already available
    if hasattr(huggingface_hub, 'cached_download'):
        return  # No need to patch

    # Create a wrapper around hf_hub_download to mimic the old cached_download
    def cached_download(*args, **kwargs):
        # Forward to the new function with appropriate args
        return huggingface_hub.hf_hub_download(*args, **kwargs)

    # Add the function to the huggingface_hub module
    setattr(huggingface_hub, 'cached_download', cached_download)

    # Make sure diffusers.utils.dynamic_modules_utils sees the patched module
    if 'diffusers.utils.dynamic_modules_utils' in sys.modules:
        del sys.modules['diffusers.utils.dynamic_modules_utils']


def load_models(device="cuda"):
    """
    Load the necessary models for stable diffusion
    :param device: (str) Device to load models on ('cuda', 'mps', or 'cpu')
    :return: (tuple) (vae, tokenizer, text_encoder, unet, scheduler, pipe)
    """
    # Apply the patch before importing diffusers
    apply_huggingface_patch()

    # Now we can safely import from diffusers
    from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel, StableDiffusionPipeline
    from transformers import CLIPTokenizer, CLIPTextModel

    # Set device
    if device == "cuda" and not torch.cuda.is_available():
        device = "mps" if torch.backends.mps.is_available() else "cpu"
    if device == "mps":
        os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = "1"
    print(f"Loading models on {device}...")

    # Load the autoencoder model which will be used to decode the latents into image space
    vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae", use_safetensors=False)

    # Load the tokenizer and text encoder to tokenize and encode the text
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

    # The UNet model for generating the latents
    unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet", use_safetensors=False)

    # The noise scheduler
    scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)

    # Load the full pipeline for concept loading
    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        use_safetensors=False
    )

    # Move models to device
    vae = vae.to(device)
    text_encoder = text_encoder.to(device)
    unet = unet.to(device)
    pipe = pipe.to(device)
    return vae, tokenizer, text_encoder, unet, scheduler, pipe
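
# Illustrative usage sketch (kept as comments so nothing runs on import; the first
# call downloads several GB of weights from the Hugging Face Hub):
#
#   vae, tokenizer, text_encoder, unet, scheduler, pipe = load_models(device="cuda")
#   set_timesteps(scheduler, num_inference_steps=30)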


def clear_gpu_memory():
    """
    Clear GPU memory cache
    """
    torch.cuda.empty_cache()
    gc.collect()


def set_timesteps(scheduler, num_inference_steps):
    """
    Set timesteps for the scheduler with MPS compatibility fix
    :param scheduler: (Scheduler) Scheduler to set timesteps for
    :param num_inference_steps: (int) Number of inference steps
    """
    scheduler.set_timesteps(num_inference_steps)
    scheduler.timesteps = scheduler.timesteps.to(torch.float32)
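
# Sketch of how the scheduler is typically driven once timesteps are set (comments
# only; `text_embeddings`, `device`, and the latent initialization are illustrative
# placeholders, not names defined in this module):
#
#   set_timesteps(scheduler, 30)
#   latents = torch.randn(1, 4, 64, 64, device=device) * scheduler.init_noise_sigma
#   for t in scheduler.timesteps:
#       latent_input = scheduler.scale_model_input(latents, t)
#       with torch.no_grad():
#           noise_pred = unet(latent_input, t, encoder_hidden_states=text_embeddings).sample
#       latents = scheduler.step(noise_pred, t, latents).prev_sample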


def pil_to_latent(input_im, vae, device):
    """
    Convert the image to latents
    :param input_im: (PIL.Image) Input PIL image
    :param vae: (VAE) VAE model
    :param device: (str) Device to run on
    :return: (torch.Tensor) Latents from VAE's encoder
    """
    from torchvision import transforms as tfms

    # Single image -> single latent in a batch (so size 1, 4, 64, 64)
    with torch.no_grad():
        latent = vae.encode(tfms.ToTensor()(input_im).unsqueeze(0).to(device) * 2 - 1)  # Note scaling
    return 0.18215 * latent.latent_dist.sample()


def latents_to_pil(latents, vae):
    """
    Convert the latents to images
    :param latents: (torch.Tensor) Latent tensor
    :param vae: (VAE) VAE model
    :return: (list) PIL images
    """
    # batch of latents -> list of images
    latents = (1 / 0.18215) * latents
    with torch.no_grad():
        image = vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    images = (image * 255).round().astype("uint8")
    pil_images = [Image.fromarray(image) for image in images]
    return pil_images
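
# Illustrative round trip between image space and latent space (comments only;
# "input.png" is a placeholder path):
#
#   img = Image.open("input.png").convert("RGB").resize((512, 512))
#   latents = pil_to_latent(img, vae, device)   # shape (1, 4, 64, 64) for a 512x512 input
#   recon = latents_to_pil(latents, vae)[0]     # approximate reconstruction of img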


def image_grid(imgs, rows, cols, labels=None):
    """
    Create a grid of images with optional labels.
    :param imgs: (list) List of PIL images to be arranged in a grid
    :param rows: (int) Number of rows in the grid
    :param cols: (int) Number of columns in the grid
    :param labels: (list, optional) List of label strings for each image
    :return: (PIL.Image) A single image with all input images arranged in a grid and labeled
    """
    assert len(imgs) == rows * cols, f"Number of images ({len(imgs)}) must equal rows*cols ({rows * cols})"
    w, h = imgs[0].size

    # Add padding at the bottom for labels if they exist
    label_height = 30 if labels else 0
    grid = Image.new('RGB', size=(cols * w, rows * h + label_height))

    # Paste images
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))

    # Add labels if provided
    if labels:
        assert len(labels) == len(imgs), "Number of labels must match number of images"
        draw = ImageDraw.Draw(grid)

        # Try to use a standard font, fall back to default if not available
        try:
            font = ImageFont.truetype("arial.ttf", 14)
        except IOError:
            font = ImageFont.load_default()

        for i, label in enumerate(labels):
            # Position text under the image
            x = (i % cols) * w + 10
            y = (i // cols + 1) * h - 5

            # Draw black text with white outline for visibility
            # White outline (draw text in each direction)
            for offset in [(1, 1), (-1, -1), (1, -1), (-1, 1)]:
                draw.text((x + offset[0], y + offset[1]), label, fill=(255, 255, 255), font=font)
            # Main text (black)
            draw.text((x, y), label, fill=(0, 0, 0), font=font)
    return grid
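
# Illustrative usage sketch (comments only; the prompt and file name are placeholders):
#
#   out = pipe("a watercolor landscape", num_images_per_prompt=4).images
#   grid = image_grid(out, rows=2, cols=2, labels=[f"sample {i}" for i in range(4)])
#   grid.save("grid.png")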


def vignette_loss(images, vignette_strength=3.0, color_shift=[1.0, 0.5, 0.0]):
    """
    Creates a strong vignette effect (dark corners) and color shift.
    :param images: (torch.Tensor) Batch of images from VAE decoder (range 0-1)
    :param vignette_strength: (float) How strong the darkening effect is (higher = more dramatic)
    :param color_shift: (list) RGB color to shift the center toward [r, g, b]
    :return: (torch.Tensor) Loss value
    """
    batch_size, channels, height, width = images.shape

    # Create coordinate grid centered at 0 with range [-1, 1]
    y = torch.linspace(-1, 1, height).view(-1, 1).repeat(1, width).to(images.device)
    x = torch.linspace(-1, 1, width).view(1, -1).repeat(height, 1).to(images.device)

    # Calculate radius from center (normalized [0,1])
    radius = torch.sqrt(x.pow(2) + y.pow(2)) / 1.414

    # Vignette mask: dark at edges, bright in center
    vignette = torch.exp(-vignette_strength * radius)

    # Color shift target: shift center toward specified color
    color_tensor = torch.tensor(color_shift, dtype=torch.float32).view(1, 3, 1, 1).to(images.device)
    center_mask = 1.0 - radius  # (height, width): bright in the center, dark at edges
    center_mask = torch.pow(center_mask, 2.0)  # Make the transition more dramatic

    # Target image with vignette and color shift
    target = images.clone()

    # Apply vignette (multiply all channels by vignette mask)
    for c in range(channels):
        target[:, c] = target[:, c] * vignette

    # Apply color shift in center
    for c in range(channels):
        # Shift toward target color more in center, less at edges
        color_offset = (color_tensor[:, c] - images[:, c]) * center_mask
        target[:, c] = target[:, c] + color_offset

    # Calculate loss - how different current image is from our target
    return torch.pow(images - target, 2).mean()
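
# Sketch of using vignette_loss as a guidance term inside a sampling loop (comments
# only; `sigma`, `noise_pred`, and `guidance_scale` are illustrative placeholders
# following the usual loss-guided diffusion pattern, not names defined in this module):
#
#   latents = latents.detach().requires_grad_()
#   latents_x0 = latents - sigma * noise_pred                       # predicted clean latents
#   decoded = vae.decode(latents_x0 / 0.18215).sample / 2 + 0.5     # images in [0, 1]
#   loss = vignette_loss(decoded, vignette_strength=3.0)
#   grad = torch.autograd.grad(loss, latents)[0]
#   latents = latents.detach() - guidance_scale * grad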


def get_concept_embedding(concept_text, tokenizer, text_encoder, device):
    """
    Generate CLIP embedding for a concept described in text
    :param concept_text: (str) Text description of the concept (e.g., "sketch painting")
    :param tokenizer: (CLIPTokenizer) CLIP tokenizer
    :param text_encoder: (CLIPTextModel) CLIP text encoder
    :param device: (str) Device to run on
    :return: (torch.Tensor) CLIP embedding for the concept
    """
    # Tokenize the concept text
    concept_tokens = tokenizer(
        concept_text,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt"
    ).input_ids.to(device)

    # Generate the embedding using the text encoder
    with torch.no_grad():
        concept_embedding = text_encoder(concept_tokens)[0]
    return concept_embedding
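

if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch, not part of the application flow):
    # load the models, embed a sample concept, and report the embedding shape.
    # Assumes the Stable Diffusion and CLIP weights can be fetched from the Hub.
    vae, tokenizer, text_encoder, unet, scheduler, pipe = load_models()
    run_device = next(text_encoder.parameters()).device
    embedding = get_concept_embedding("sketch painting", tokenizer, text_encoder, run_device)
    print(f"Concept embedding shape: {tuple(embedding.shape)}")  # (1, 77, 768) for CLIP ViT-L/14
    clear_gpu_memory()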