#!/usr/bin/env python3
"""
Gradio Application for Stable Diffusion
Author: Shilpaj Bhalerao
Date: Feb 26, 2025
"""
import gc
import os

import torch
import gradio as gr
# import spaces
from tqdm.auto import tqdm
from PIL import Image

from utils import (
    load_models,
    clear_gpu_memory,
    set_timesteps,
    latents_to_pil,
    vignette_loss,
    get_concept_embedding,
    image_grid
)
# Remove this import to avoid the cached_download error
# from diffusers import StableDiffusionPipeline


def generate_latents(prompt, seed, num_inference_steps, guidance_scale, vignette_loss_scale,
                     concept, concept_strength, height, width):
    """
    Generate latents from the UNet.

    :param prompt: Text prompt
    :param seed: Seed for the random number generator
    :param num_inference_steps: Number of denoising steps
    :param guidance_scale: Classifier-free guidance scale
    :param vignette_loss_scale: Weight applied to the vignette loss guidance
    :param concept: Concept to influence generation (optional)
    :param concept_strength: How strongly to apply the concept (0.0-1.0)
    :param height: Image height in pixels
    :param width: Image width in pixels
    :return: Latents of the UNet. These are passed to the VAE to generate the image.
    """
    global art_concepts

    # Batch size
    batch_size = 1

    # Set the seed
    generator = torch.manual_seed(seed)

    # Prep text
    text_input = tokenizer([prompt], padding="max_length", max_length=tokenizer.model_max_length,
                           truncation=True, return_tensors="pt")
    with torch.no_grad():
        text_embeddings = text_encoder(text_input.input_ids.to(device))[0]

    # Get the concept embedding
    concept_embedding = art_concepts[concept]

    # Apply concept embedding influence if provided
    if concept_embedding is not None and concept_strength > 0:
        # Fix the dimension mismatch by adding a batch dimension to concept_embedding if needed
        if len(concept_embedding.shape) == 2 and len(text_embeddings.shape) == 3:
            # Add batch dimension to concept_embedding to match text_embeddings
            concept_embedding = concept_embedding.unsqueeze(0)

        # Create a weighted blend between the original text embedding and the concept
        if text_embeddings.shape == concept_embedding.shape:
            # Interpolate between text embeddings and concept
            text_embeddings = (1 - concept_strength) * text_embeddings + concept_strength * concept_embedding
            print(f"Successfully applied concept with strength {concept_strength}")
        else:
            print(f"Warning: Shapes still incompatible after adjustment. "
                  f"Concept: {concept_embedding.shape}, Text: {text_embeddings.shape}")

    # And the unconditional input as before:
    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer(
        [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
    )
    with torch.no_grad():
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    # Prep scheduler
    set_timesteps(scheduler, num_inference_steps)

    # Prep latents
    latents = torch.randn(
        (batch_size, unet.in_channels, height // 8, width // 8),
        generator=generator,
    )
    latents = latents.to(device)
    latents = latents * scheduler.init_noise_sigma

    # Denoising loop
    for i, t in tqdm(enumerate(scheduler.timesteps), total=len(scheduler.timesteps)):
        # Expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
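        # A single UNet forward pass on the duplicated latents yields both the
        # unconditional and the text-conditioned noise predictions, which are
        # split again below for classifier-free guidance.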
        latent_model_input = torch.cat([latents] * 2)
        sigma = scheduler.sigmas[i]
        latent_model_input = scheduler.scale_model_input(latent_model_input, t)

        # Predict the noise residual
        with torch.no_grad():
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]

        # Perform classifier-free guidance
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        #### ADDITIONAL GUIDANCE ###
        if i % 5 == 0:
            # Require grad on the latents
            latents = latents.detach().requires_grad_()

            # Get the predicted x0:
            latents_x0 = latents - sigma * noise_pred
            # latents_x0 = scheduler.step(noise_pred, t, latents).pred_original_sample

            # Decode to image space
            denoised_images = vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5  # range (0, 1)

            # Calculate loss
            loss = vignette_loss(denoised_images) * vignette_loss_scale

            # Occasionally print it out
            if i % 10 == 0:
                print(i, 'loss:', loss.item())

            # Get gradient
            cond_grad = torch.autograd.grad(loss, latents)[0]

            # Modify the latents based on this gradient
            latents = latents.detach() - cond_grad * sigma ** 2

        # Now step with the scheduler
        latents = scheduler.step(noise_pred, t, latents).prev_sample

    return latents


def generate_image(prompt, seed=42, num_inference_steps=30, guidance_scale=7.5, vignette_loss_scale=0.0,
                   concept="none", concept_strength=0.5, height=512, width=512):
    """
    Generate a single image.
    """
    global vae
    latents = generate_latents(prompt, seed, num_inference_steps, guidance_scale, vignette_loss_scale,
                               concept, concept_strength, height, width)
    generated_image = latents_to_pil(latents, vae)
    return image_grid(generated_image, 1, 1, None)


def generate_style_images(prompt, num_inference_steps=30, guidance_scale=7.5, vignette_loss_scale=0.0,
                          concept_strength=0.5, height=512, width=512):
    """
    Generate one image per style concept and arrange them in a grid.
    """
    global art_concepts, vae

    seed_list = [2000, 1000, 500, 600, 100]
    latents_collect = []
    concept_labels = []

    # Load the concept names and remove the "none" element
    concepts_list = list(art_concepts.keys())
    concepts_list.remove("none")

    for seed_no, concept in zip(seed_list, concepts_list):
        # Clear the CUDA cache
        torch.cuda.empty_cache()
        gc.collect()
        torch.cuda.empty_cache()

        print(f"Generating image with concept '{concept}' at strength {concept_strength}")

        # Generate latents using the concept embedding
        latents = generate_latents(prompt, seed_no, num_inference_steps, guidance_scale, vignette_loss_scale,
                                   concept, concept_strength, height, width)
        latents_collect.append(latents)
        concept_labels.append(f"{concept} ({concept_strength})")

    # Show results
    latents_collect = torch.vstack(latents_collect)
    images = latents_to_pil(latents_collect, vae)
    return image_grid(images, 1, len(seed_list), concept_labels)


# Define Gradio interface
# @spaces.GPU(enable_queue=False)
def create_demo():
    with gr.Blocks(title="Guided Stable Diffusion with Styles") as demo:
        gr.Markdown("# Guided Stable Diffusion with Styles")

        with gr.Tab("Single Image Generation"):
            with gr.Row():
                with gr.Column():
                    all_styles = list(art_concepts.keys())
                    all_styles.remove("none")           # Remove "none" to avoid duplication
                    all_styles = ["none"] + all_styles  # Add it back at the beginning

                    prompt = gr.Textbox(label="Prompt", placeholder="A cat sitting on a chair")
                    seed = gr.Slider(minimum=0, maximum=10000, step=1, label="Seed", value=1000)
                    concept_style = gr.Dropdown(choices=all_styles, label="Style Concept", value="none")
                    concept_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label="Concept Strength",
                                                 value=0.5)
                    num_inference_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=30)
                    height = gr.Slider(minimum=256, maximum=1024, step=1, label="Height", value=512)
                    width = gr.Slider(minimum=256, maximum=1024, step=1, label="Width", value=512)
                    guidance_scale = gr.Slider(minimum=1.0, maximum=15.0, step=0.1, label="Guidance Scale", value=8.0)
                    vignette_loss_scale = gr.Slider(minimum=0.0, maximum=100.0, step=1.0, label="Vignette Loss Scale", value=70.0)
                    generate_btn = gr.Button("Generate Image")
                with gr.Column():
                    output_image = gr.Image(label="Generated Image", type="pil")

        with gr.Tab("Style Grid"):
            with gr.Row():
                with gr.Column():
                    grid_prompt = gr.Textbox(label="Prompt", placeholder="A dog running in the park")
                    grid_num_inference_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=30)
                    grid_guidance_scale = gr.Slider(minimum=1.0, maximum=15.0, step=0.1, label="Guidance Scale", value=8.0)
                    grid_vignette_loss_scale = gr.Slider(minimum=0.0, maximum=100.0, step=1.0, label="Vignette Loss Scale", value=70.0)
                    grid_concept_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label="Concept Strength", value=0.5)
                    grid_generate_btn = gr.Button("Generate Style Grid")
                with gr.Column():
                    output_grid = gr.Image(label="Style Grid", type="pil")

        # Set up event handlers
        generate_btn.click(
            generate_image,
            inputs=[prompt, seed, num_inference_steps, guidance_scale, vignette_loss_scale,
                    concept_style, concept_strength, height, width],
            outputs=output_image
        )

        grid_generate_btn.click(
            generate_style_images,
            inputs=[grid_prompt, grid_num_inference_steps, grid_guidance_scale,
                    grid_vignette_loss_scale, grid_concept_strength],
            outputs=output_grid
        )

    return demo


# Launch the app
if __name__ == "__main__":
    # Set device
    device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    if device == "mps":
        os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = "1"

    # Load models
    vae, tokenizer, text_encoder, unet, scheduler, pipe = load_models(device=device)

    # Define art style concepts
    art_concepts = {
        "sketch_painting": get_concept_embedding("a sketch painting, pencil drawing, hand-drawn illustration",
                                                 tokenizer, text_encoder, device),
        "oil_painting": get_concept_embedding("an oil painting, textured canvas, painterly technique",
                                              tokenizer, text_encoder, device),
        "watercolor": get_concept_embedding("a watercolor painting, fluid, soft edges",
                                            tokenizer, text_encoder, device),
        "digital_art": get_concept_embedding("digital art, computer generated, precise details",
                                             tokenizer, text_encoder, device),
        "comic_book": get_concept_embedding("comic book style, ink outlines, cel shading",
                                            tokenizer, text_encoder, device),
        "none": None
    }

    demo = create_demo()
    demo.launch(debug=True)
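

# ---------------------------------------------------------------------------
# Illustrative only, not used by the app above: `vignette_loss` is imported
# from utils.py and its exact definition is not part of this file. The sketch
# below is one plausible formulation (an assumption, not the project's
# implementation): it rewards darker corners and a brighter centre on images
# in the (0, 1) range, which is the kind of differentiable signal the guidance
# step in generate_latents back-propagates into the latents.
# ---------------------------------------------------------------------------
def example_vignette_loss(images: torch.Tensor) -> torch.Tensor:
    """Toy vignette loss: penalise brightness far from the image centre."""
    _, _, height, width = images.shape
    ys = torch.linspace(-1.0, 1.0, height, device=images.device).view(1, 1, height, 1)
    xs = torch.linspace(-1.0, 1.0, width, device=images.device).view(1, 1, 1, width)
    # Radial weight: 0 at the centre, 1 at the corners
    radius = torch.sqrt(xs ** 2 + ys ** 2)
    weight = radius / radius.max()
    # Bright pixels near the edges increase the loss, so gradient descent darkens them
    return (images * weight).mean()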