# StableDiffusion / app.py
# (Hugging Face Space header — author: jatingocodeo, commit e21143a "Update app.py";
#  kept as a comment so the file remains valid Python.)
import gradio as gr
import torch
from diffusers import StableDiffusionPipeline
from PIL import Image
import numpy as np
import os
from huggingface_hub import hf_hub_download
import warnings
from transformers import CLIPProcessor, CLIPModel
warnings.filterwarnings("ignore")

# Check if CUDA is available
# Device is chosen once at import time; every model below is moved to it.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load CLIP model for semantic guidance
# CLIP scores image/text similarity; car_loss uses these two objects to rank
# how "car-like" a generated image is. Loaded eagerly so the first request
# doesn't pay the download cost.
print("Loading CLIP model for semantic guidance...")
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
# Dictionary of available concepts
# Maps a textual-inversion concept name to its metadata:
#   repo_id:     sd-concepts-library Hub repo holding learned_embeds.bin
#   type:        "object" or "style" — controls how the prompt is built
#                (object concepts are appended to the scene, style concepts
#                prefix it; see generate_images)
#   description: human-readable label surfaced in the UI
CONCEPTS = {
    "canna-lily-flowers102": {
        "repo_id": "sd-concepts-library/canna-lily-flowers102",
        "type": "object",
        "description": "Canna lily flower style"
    },
    "samurai-jack": {
        "repo_id": "sd-concepts-library/samurai-jack",
        "type": "style",
        "description": "Samurai Jack animation style"
    },
    "babies-poster": {
        "repo_id": "sd-concepts-library/babies-poster",
        "type": "style",
        "description": "Babies poster art style"
    },
    "animal-toy": {
        "repo_id": "sd-concepts-library/animal-toy",
        "type": "object",
        "description": "Animal toy style"
    },
    "sword-lily-flowers102": {
        "repo_id": "sd-concepts-library/sword-lily-flowers102",
        "type": "object",
        "description": "Sword lily flower style"
    }
}
def car_loss(image):
    """Score how strongly CLIP believes the image contains a car.

    Args:
        image: a PIL.Image, an HxWxC uint8-range numpy array, or a torch
            tensor of the same shape.

    Returns:
        A scalar torch tensor: -(car_score - no_car_score). Lower (more
        negative) means CLIP is more confident a car is present.
    """
    # Normalize the input to a PIL image for the CLIP processor.
    # FIX: the original converted PIL -> numpy -> torch tensor -> numpy -> PIL
    # for no benefit, and raised AttributeError (`.cpu()` on ndarray) when
    # handed a bare numpy array.
    if isinstance(image, torch.Tensor):
        image = Image.fromarray(image.cpu().numpy().astype(np.uint8))
    elif isinstance(image, np.ndarray):
        image = Image.fromarray(image.astype(np.uint8))

    with torch.no_grad():
        # Joint image/text encoding: compare against a "car" and a "no car"
        # caption so the score is contrastive rather than absolute.
        inputs = clip_processor(
            text=["a photo of a car", "a photo without cars"],
            images=image,
            return_tensors="pt",
            padding=True
        ).to(device)
        outputs = clip_model(**inputs)
        logits_per_image = outputs.logits_per_image
        # Higher similarity to the first caption (with cars) is better.
        car_score = logits_per_image[0][0]
        no_car_score = logits_per_image[0][1]
        # We want to maximize car_score and minimize no_car_score.
        loss = -(car_score - no_car_score)
    return loss
def generate_image(pipe, prompt, seed, guidance_scale=7.5, num_inference_steps=30, use_car_guidance=False):
    """Generate an image with optional CLIP-based car guidance.

    Args:
        pipe: a loaded StableDiffusionPipeline.
        prompt: text prompt for generation.
        seed: integer seed for reproducible sampling.
        guidance_scale: classifier-free guidance strength.
        num_inference_steps: denoising steps for the plain path; the guided
            path spends half on the initial draft.
        use_car_guidance: when True, iteratively refine with img2img and
            return the refinement that car_loss scores as most car-like.

    Returns:
        A PIL.Image. The guided path falls back to plain generation on any
        failure.
    """
    generator = torch.Generator(device).manual_seed(seed)

    if not use_car_guidance:
        return pipe(
            prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator
        ).images[0]

    try:
        # Initial draft at half the step budget; refinement spends the rest.
        init_image = pipe(
            prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps // 2,
            generator=generator
        ).images[0]

        # Build an img2img pipeline that reuses the already-loaded components
        # (no second model download / allocation).
        from diffusers import StableDiffusionImg2ImgPipeline
        img2img_pipe = StableDiffusionImg2ImgPipeline(
            vae=pipe.vae,
            text_encoder=pipe.text_encoder,
            tokenizer=pipe.tokenizer,
            unet=pipe.unet,
            scheduler=pipe.scheduler,
            safety_checker=None,
            feature_extractor=None,
        ).to(device)

        # BUG FIX: the original computed car_loss every round and discarded
        # the result, so the "guidance" was prompt text only. Track the
        # lowest-loss (most car-like) image seen and return that instead of
        # blindly returning the last refinement.
        strength = 0.75
        current_image = init_image
        best_image = init_image
        best_loss = car_loss(init_image)
        for _ in range(5):
            current_image = img2img_pipe(
                prompt=prompt + ", with beautiful cars",
                image=current_image,
                strength=strength,
                guidance_scale=guidance_scale,
                generator=generator,
            ).images[0]
            loss = car_loss(current_image)
            if loss < best_loss:
                best_loss = loss
                best_image = current_image
            strength *= 0.8  # progressively gentler refinement each round
        return best_image
    except Exception as e:
        # Best-effort: log and fall back to an unguided generation rather
        # than failing the whole request.
        print(f"Error in car-guided generation: {e}")
        return pipe(
            prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator
        ).images[0]
# Cache for loaded models and concepts
# Maps concept name -> StableDiffusionPipeline with that concept's
# textual-inversion embedding already loaded (see get_model_with_concept).
loaded_models = {}
def get_model_with_concept(concept_name):
    """Return a cached Stable Diffusion pipeline with the concept loaded.

    On first use this downloads the concept's learned embedding from the
    Hub, keeps a local copy under concepts/, builds the base SD 2 pipeline,
    and loads the textual-inversion embedding into it. Later calls for the
    same concept hit the in-memory cache.
    """
    cached = loaded_models.get(concept_name)
    if cached is not None:
        return cached

    concept_info = CONCEPTS[concept_name]

    # Keep a local copy of the embedding so restarts skip the Hub download.
    concept_path = f"concepts/{concept_name}.bin"
    os.makedirs("concepts", exist_ok=True)
    if not os.path.exists(concept_path):
        downloaded = hf_hub_download(
            repo_id=concept_info["repo_id"],
            filename="learned_embeds.bin",
            repo_type="model"
        )
        import shutil
        shutil.copy(downloaded, concept_path)

    # Half precision only on GPU; fp16 is not supported on CPU.
    dtype = torch.float32 if device == "cpu" else torch.float16
    pipe = StableDiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2",
        torch_dtype=dtype,
        safety_checker=None
    ).to(device)
    pipe.load_textual_inversion(concept_path)

    loaded_models[concept_name] = pipe
    return pipe
def generate_images(concept_name, base_prompt, seed, use_car_guidance):
    """Gradio callback: generate one image for the selected concept.

    Builds the concept-aware prompt — object concepts are appended to the
    scene, style concepts prefix it — then delegates to generate_image.

    Raises:
        gr.Error: wraps any underlying failure so Gradio shows a friendly
            message to the user.
    """
    try:
        # Get (or lazily load) the pipeline with this concept's embedding.
        pipe = get_model_with_concept(concept_name)

        # NOTE(review): assumes the embedding's placeholder token is
        # "<concept-name>" — the sd-concepts-library convention; verify if
        # new repos are added.
        if CONCEPTS[concept_name]["type"] == "object":
            prompt = f"A {base_prompt} with a <{concept_name}>"
        else:
            prompt = f"<{concept_name}> {base_prompt}"

        image = generate_image(
            pipe=pipe,
            prompt=prompt,
            seed=int(seed),
            use_car_guidance=use_car_guidance
        )
        return image
    except Exception as e:
        # FIX: chain the original exception (`from e`) so the server log
        # keeps the real traceback instead of just the wrapped message.
        raise gr.Error(f"Error generating image: {str(e)}") from e
# Create Gradio interface
with gr.Blocks(title="Stable Diffusion Style Explorer") as demo:
    gr.Markdown("""
# Stable Diffusion Style Explorer
Generate images using various concepts from the SD Concepts Library, with optional car guidance.
## How to use:
1. Select a concept from the dropdown
2. Enter a base prompt (or use the default)
3. Set a seed for reproducibility
4. Choose whether to use car guidance
5. Click Generate!
Check out the examples below to see different combinations of concepts and prompts!
""")

    with gr.Row():
        with gr.Column():
            concept = gr.Dropdown(
                choices=list(CONCEPTS.keys()),
                value="samurai-jack",
                label="Select Concept"
            )
            prompt = gr.Textbox(
                value="A serene landscape with mountains and a lake at sunset",
                label="Base Prompt"
            )
            seed = gr.Number(
                value=42,
                label="Seed",
                precision=0
            )
            car_guidance = gr.Checkbox(
                value=False,
                label="Use Car Guidance"
            )
            generate_btn = gr.Button("Generate Image")
        with gr.Column():
            output_image = gr.Image(label="Generated Image")

    # FIX: the original passed a freshly-instantiated gr.Markdown() inside
    # `outputs=` and returned another gr.Markdown component from the callback,
    # so updates targeted a throwaway component created at event-registration
    # time. Create the output component once, in the layout, and return a
    # plain string for it.
    concept_description = gr.Markdown()
    concept.change(
        fn=lambda x: f"Selected concept: {CONCEPTS[x]['description']} ({CONCEPTS[x]['type']})",
        inputs=[concept],
        outputs=[concept_description]
    )

    generate_btn.click(
        fn=generate_images,
        inputs=[concept, prompt, seed, car_guidance],
        outputs=[output_image]
    )

    # Gallery of pre-generated examples
    gr.Markdown("### 🖼️ Pre-generated Examples")
    with gr.Row():
        # Samurai Jack examples
        with gr.Column():
            gr.Markdown("**Samurai Jack Style**")
            gr.Image("samurai-jack_normal.png", label="Without Car Guidance")
            gr.Image("samurai-jack_car.png", label="With Car Guidance")
    with gr.Row():
        # Canna Lily examples
        with gr.Column():
            gr.Markdown("**Canna Lily Object**")
            gr.Image("canna-lily-flowers102_normal.png", label="Without Car Guidance")
            gr.Image("canna-lily-flowers102_car.png", label="With Car Guidance")
    with gr.Row():
        # Babies Poster examples
        with gr.Column():
            gr.Markdown("**Babies Poster Style**")
            gr.Image("babies-poster_normal.png", label="Without Car Guidance")
            gr.Image("babies-poster_car.png", label="With Car Guidance")
    with gr.Row():
        # Animal Toy examples
        with gr.Column():
            gr.Markdown("**Animal Toy Object**")
            gr.Image("animal-toy_normal.png", label="Without Car Guidance")
            gr.Image("animal-toy_car.png", label="With Car Guidance")
    with gr.Row():
        # Sword Lily examples
        with gr.Column():
            gr.Markdown("**Sword Lily Object**")
            gr.Image("sword-lily-flowers102_normal.png", label="Without Car Guidance")
            gr.Image("sword-lily-flowers102_car.png", label="With Car Guidance")

demo.launch()