import gradio as gr
import torch
from diffusers import StableDiffusionPipeline
from PIL import Image
import numpy as np
import os
from huggingface_hub import hf_hub_download
import warnings
from transformers import CLIPProcessor, CLIPModel

warnings.filterwarnings("ignore")

# Check if CUDA is available
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load CLIP model for semantic guidance
print("Loading CLIP model for semantic guidance...")
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
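# CLIP is used purely as a zero-shot scorer (see car_loss below): a generated image is compared
# against the prompts "a photo of a car" and "a photo without cars", and the margin between the
# two image-text logits is treated as the guidance signal.
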
# Dictionary of available concepts
CONCEPTS = {
    "canna-lily-flowers102": {
        "repo_id": "sd-concepts-library/canna-lily-flowers102",
        "type": "object",
        "description": "Canna lily flower style"
    },
    "samurai-jack": {
        "repo_id": "sd-concepts-library/samurai-jack",
        "type": "style",
        "description": "Samurai Jack animation style"
    },
    "babies-poster": {
        "repo_id": "sd-concepts-library/babies-poster",
        "type": "style",
        "description": "Babies poster art style"
    },
    "animal-toy": {
        "repo_id": "sd-concepts-library/animal-toy",
        "type": "object",
        "description": "Animal toy style"
    },
    "sword-lily-flowers102": {
        "repo_id": "sd-concepts-library/sword-lily-flowers102",
        "type": "object",
        "description": "Sword lily flower style"
    }
}
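# Each sd-concepts-library repo ships a `learned_embeds.bin` textual-inversion embedding whose
# placeholder token normally matches the repo name (e.g. `<samurai-jack>`); this is why
# generate_images() below inserts `<{concept_name}>` verbatim into the prompt.
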
def car_loss(image):
    """Custom loss function that encourages the presence of cars in the image."""
    # Accept either a PIL image or an array/tensor of pixel values
    if isinstance(image, Image.Image):
        pil_image = image
    else:
        if isinstance(image, torch.Tensor):
            image = image.cpu().numpy()
        pil_image = Image.fromarray(np.asarray(image).astype(np.uint8))

    # Score the image with CLIP; no gradients are needed because the value is only
    # used as a score to monitor refinement, not backpropagated
    with torch.no_grad():
        inputs = clip_processor(
            text=["a photo of a car", "a photo without cars"],
            images=pil_image,
            return_tensors="pt",
            padding=True
        ).to(device)

        # Get image-text similarity scores
        outputs = clip_model(**inputs)
        logits_per_image = outputs.logits_per_image

        # Higher score for the first text (with cars) is better
        car_score = logits_per_image[0][0]
        no_car_score = logits_per_image[0][1]

        # We want to maximize car_score and minimize no_car_score,
        # so the loss is the negated margin between them
        loss = -(car_score - no_car_score)

    return loss
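# Example (sketch, hypothetical file name): a lower, more negative value means CLIP judges the
# image to be more "car-like" relative to the car-free prompt.
#
#   img = Image.open("some_generated_image.png")   # illustrative path only
#   print(car_loss(img).item())
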
def generate_image(pipe, prompt, seed, guidance_scale=7.5, num_inference_steps=30, use_car_guidance=False):
    """Generate an image with optional car guidance."""
    generator = torch.Generator(device).manual_seed(seed)
    custom_loss = car_loss if use_car_guidance else None

    if custom_loss:
        try:
            # Start with a standard generation using half the steps; the rest of the budget
            # goes to the img2img refinement passes below
            init_images = pipe(
                prompt,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps // 2,
                generator=generator
            ).images
            init_image = init_images[0]

            # Refine using car guidance: build an img2img pipeline that shares the
            # components of the already-loaded text-to-image pipeline
            from diffusers import StableDiffusionImg2ImgPipeline
            img2img_pipe = StableDiffusionImg2ImgPipeline(
                vae=pipe.vae,
                text_encoder=pipe.text_encoder,
                tokenizer=pipe.tokenizer,
                unet=pipe.unet,
                scheduler=pipe.scheduler,
                safety_checker=None,
                feature_extractor=None,
                requires_safety_checker=False,
            ).to(device)

            strength = 0.75
            current_image = init_image
            for i in range(5):
                # Log the CLIP-based score so refinement progress is visible
                current_loss = custom_loss(current_image)
                print(f"Refinement step {i + 1}/5, car loss: {current_loss.item():.4f}")

                refined_images = img2img_pipe(
                    prompt=prompt + ", with beautiful cars",
                    image=current_image,
                    strength=strength,
                    guidance_scale=guidance_scale,
                    generator=generator,
                ).images
                current_image = refined_images[0]

                # Lower the strength each round so later passes make smaller edits
                strength *= 0.8

            return current_image
        except Exception as e:
            print(f"Error in car-guided generation: {e}")
            # Fall back to a plain generation if anything goes wrong
            return pipe(
                prompt,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                generator=generator
            ).images[0]
    else:
        return pipe(
            prompt,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator
        ).images[0]
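# Example (sketch): generating the same prompt with and without car guidance for comparison,
# assuming `pipe` has already been loaded via get_model_with_concept() below.
#
#   img_plain  = generate_image(pipe, "a mountain road at dawn", seed=42)
#   img_guided = generate_image(pipe, "a mountain road at dawn", seed=42, use_car_guidance=True)
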
# Cache for loaded models and concepts
loaded_models = {}

def get_model_with_concept(concept_name):
    """Get or load a model with the specified concept."""
    if concept_name not in loaded_models:
        concept_info = CONCEPTS[concept_name]

        # Download the concept embedding if it is not cached locally yet
        concept_path = f"concepts/{concept_name}.bin"
        os.makedirs("concepts", exist_ok=True)
        if not os.path.exists(concept_path):
            file = hf_hub_download(
                repo_id=concept_info["repo_id"],
                filename="learned_embeds.bin",
                repo_type="model"
            )
            import shutil
            shutil.copy(file, concept_path)

        # Load the base model and the concept embedding. SD Concepts Library embeddings were
        # trained against the Stable Diffusion 1.x text encoder, so a 1.x checkpoint is used
        # here; their 768-dimensional vectors cannot be loaded into stable-diffusion-2.
        pipe = StableDiffusionPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            torch_dtype=torch.float32 if device == "cpu" else torch.float16,
            safety_checker=None
        ).to(device)
        pipe.load_textual_inversion(concept_path)
        loaded_models[concept_name] = pipe
    return loaded_models[concept_name]
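# Example (sketch): the first call downloads the embedding and builds the pipeline; subsequent
# calls for the same concept reuse the cached pipeline.
#
#   pipe = get_model_with_concept("samurai-jack")
#   image = pipe("<samurai-jack> a castle at dusk").images[0]   # illustrative prompt only
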
def generate_images(concept_name, base_prompt, seed, use_car_guidance):
    """Generate images using the selected concept."""
    try:
        # Get model with concept
        pipe = get_model_with_concept(concept_name)

        # Construct prompt based on concept type
        if CONCEPTS[concept_name]["type"] == "object":
            prompt = f"A {base_prompt} with a <{concept_name}>"
        else:
            prompt = f"<{concept_name}> {base_prompt}"

        # Generate image
        image = generate_image(
            pipe=pipe,
            prompt=prompt,
            seed=int(seed),
            use_car_guidance=use_car_guidance
        )
        return image
    except Exception as e:
        raise gr.Error(f"Error generating image: {str(e)}")

# Create Gradio interface
with gr.Blocks(title="Stable Diffusion Style Explorer") as demo:
    gr.Markdown("""
    # Stable Diffusion Style Explorer

    Generate images using various concepts from the SD Concepts Library, with optional car guidance.

    ## How to use:
    1. Select a concept from the dropdown
    2. Enter a base prompt (or use the default)
    3. Set a seed for reproducibility
    4. Choose whether to use car guidance
    5. Click Generate!

    Check out the examples below to see different combinations of concepts and prompts!
    """)

    with gr.Row():
        with gr.Column():
            concept = gr.Dropdown(
                choices=list(CONCEPTS.keys()),
                value="samurai-jack",
                label="Select Concept"
            )
            prompt = gr.Textbox(
                value="A serene landscape with mountains and a lake at sunset",
                label="Base Prompt"
            )
            seed = gr.Number(
                value=42,
                label="Seed",
                precision=0
            )
            car_guidance = gr.Checkbox(
                value=False,
                label="Use Car Guidance"
            )
            generate_btn = gr.Button("Generate Image")
        with gr.Column():
            output_image = gr.Image(label="Generated Image")

    # Shows a short description of the currently selected concept
    concept_info = gr.Markdown()
    concept.change(
        fn=lambda x: f"Selected concept: {CONCEPTS[x]['description']} ({CONCEPTS[x]['type']})",
        inputs=[concept],
        outputs=[concept_info]
    )

    generate_btn.click(
        fn=generate_images,
        inputs=[concept, prompt, seed, car_guidance],
        outputs=[output_image]
    )
    # Gallery of pre-generated examples
    gr.Markdown("### 🖼️ Pre-generated Examples")

    with gr.Row():
        # Samurai Jack examples
        with gr.Column():
            gr.Markdown("**Samurai Jack Style**")
            gr.Image("samurai-jack_normal.png",
                     label="Without Car Guidance")
            gr.Image("samurai-jack_car.png",
                     label="With Car Guidance")

    with gr.Row():
        # Canna Lily examples
        with gr.Column():
            gr.Markdown("**Canna Lily Object**")
            gr.Image("canna-lily-flowers102_normal.png",
                     label="Without Car Guidance")
            gr.Image("canna-lily-flowers102_car.png",
                     label="With Car Guidance")

    with gr.Row():
        # Babies Poster examples
        with gr.Column():
            gr.Markdown("**Babies Poster Style**")
            gr.Image("babies-poster_normal.png",
                     label="Without Car Guidance")
            gr.Image("babies-poster_car.png",
                     label="With Car Guidance")

    with gr.Row():
        # Animal Toy examples
        with gr.Column():
            gr.Markdown("**Animal Toy Object**")
            gr.Image("animal-toy_normal.png",
                     label="Without Car Guidance")
            gr.Image("animal-toy_car.png",
                     label="With Car Guidance")

    with gr.Row():
        # Sword Lily examples
        with gr.Column():
            gr.Markdown("**Sword Lily Object**")
            gr.Image("sword-lily-flowers102_normal.png",
                     label="Without Car Guidance")
            gr.Image("sword-lily-flowers102_car.png",
                     label="With Car Guidance")

demo.launch()