# File-viewer header captured together with this file (not part of the program):
#   author: Avijit Ghosh
#   commit: df3866a ("add login")
#   size:   7.36 kB
import os
from pathlib import Path
from typing import Optional

import gradio as gr
import torch
from diffusers import (
    DiffusionPipeline,
    StableDiffusionPipeline,
    StableDiffusionXLPipeline,
    EulerDiscreteScheduler,
    UNet2DConditionModel,
    StableDiffusion3Pipeline,
)
from transformers import BlipProcessor, BlipForConditionalGeneration
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib.colors import hex2color
import stone
import spaces
# Define model initialization functions
def load_model(model_name, token=None):
    """Build the text-to-image pipeline for ``model_name`` on the CUDA device.

    Args:
        model_name: Hub repository id of one of the supported checkpoints.
        token: Hugging Face access token. Only used for the Stable Diffusion 3
            checkpoint. Either the token string itself or an object that
            carries it in a ``token`` attribute is accepted.

    Returns:
        The float16 pipeline, already moved to ``"cuda"``.

    Raises:
        ValueError: If ``model_name`` is not one of the supported ids, or if
            the Stable Diffusion 3 checkpoint is requested without a token.
    """
    # Callers may hand over a token object rather than the bare string;
    # plain strings and None pass through unchanged.
    hf_token = getattr(token, "token", token)

    if model_name == "stabilityai/sdxl-turbo":
        pipeline = DiffusionPipeline.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            variant="fp16",
        ).to("cuda")
    elif model_name == "runwayml/stable-diffusion-v1-5":
        pipeline = StableDiffusionPipeline.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
        ).to("cuda")
    elif model_name == "ByteDance/SDXL-Lightning":
        # The Lightning repo only provides UNet weights: build an SDXL-base
        # UNet from its config, load the 4-step checkpoint into it, and plug
        # it into the regular SDXL base pipeline.
        base = "stabilityai/stable-diffusion-xl-base-1.0"
        ckpt = "sdxl_lightning_4step_unet.safetensors"
        unet = UNet2DConditionModel.from_config(base, subfolder="unet").to("cuda", torch.float16)
        unet.load_state_dict(load_file(hf_hub_download(model_name, ckpt), device="cuda"))
        pipeline = StableDiffusionXLPipeline.from_pretrained(
            base,
            unet=unet,
            torch_dtype=torch.float16,
            variant="fp16",
        ).to("cuda")
        pipeline.scheduler = EulerDiscreteScheduler.from_config(
            pipeline.scheduler.config, timestep_spacing="trailing"
        )
    elif model_name == "segmind/SSD-1B":
        pipeline = StableDiffusionXLPipeline.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            use_safetensors=True,
            variant="fp16",
        ).to("cuda")
    elif model_name == "stabilityai/stable-diffusion-3-medium-diffusers":
        if not hf_token:
            raise ValueError("Hugging Face token is required to access this model")
        # `token` replaces the deprecated `use_auth_token` keyword.
        pipeline = StableDiffusion3Pipeline.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            token=hf_token,
        ).to("cuda")
    else:
        raise ValueError("Unknown model name")
    return pipeline
# Initialize the default model at import time. `default_model` is also the
# initial value of the model dropdown, and `pipeline_text2image` is the
# module-level pipeline that getimgen and generate_images_plots rebind through
# their `global` declarations.
default_model = "stabilityai/sdxl-turbo"
pipeline_text2image = load_model(default_model)
@spaces.GPU
def getimgen(prompt, model_name, token=None):
    """Generate a single image for ``prompt`` with the requested model.

    The module-level ``pipeline_text2image`` is reused when it already holds
    the requested checkpoint; otherwise the checkpoint is loaded and the
    global is rebound to it.

    Args:
        prompt: Text prompt handed to the pipeline.
        model_name: Hub repository id of one of the supported checkpoints.
        token: Optional Hugging Face token, forwarded to ``load_model`` when
            the checkpoint has to be (re)loaded.

    Returns:
        The first image produced by the pipeline call.

    Raises:
        ValueError: If ``model_name`` is not supported, or if ``load_model``
            rejects the request (e.g. gated checkpoint without a token).
    """
    global pipeline_text2image

    # Per-model call settings; the prompt is always passed by keyword.
    generation_settings = {
        "stabilityai/sdxl-turbo": {"guidance_scale": 0.0, "num_inference_steps": 2},
        "runwayml/stable-diffusion-v1-5": {},
        "ByteDance/SDXL-Lightning": {"num_inference_steps": 4, "guidance_scale": 0},
        "segmind/SSD-1B": {"negative_prompt": "ugly, blurry, poor quality"},
        "stabilityai/stable-diffusion-3-medium-diffusers": {
            "negative_prompt": "",
            "num_inference_steps": 28,
            "guidance_scale": 7.0,
        },
    }
    if model_name not in generation_settings:
        raise ValueError("Unknown model name")

    # Avoid reloading the weights for every image. diffusers pipelines report
    # the id passed to from_pretrained as `name_or_path`; when that attribute
    # is missing or differs (as for SDXL-Lightning, which is assembled from the
    # SDXL base repo) we fall back to loading, exactly as before.
    current = globals().get("pipeline_text2image")
    if getattr(current, "name_or_path", None) != model_name:
        pipeline_text2image = load_model(model_name, token)

    result = pipeline_text2image(prompt=prompt, **generation_settings[model_name])
    return result.images[0]
# BLIP captioning checkpoint shared by the processor and the model.
_BLIP_CHECKPOINT = "Salesforce/blip-image-captioning-large"
blip_processor = BlipProcessor.from_pretrained(_BLIP_CHECKPOINT)
# The captioning model is loaded in float16 and then moved to the CUDA device.
blip_model = BlipForConditionalGeneration.from_pretrained(
    _BLIP_CHECKPOINT,
    torch_dtype=torch.float16,
)
blip_model = blip_model.to("cuda")
@spaces.GPU
def blip_caption_image(image, prefix):
    """Caption ``image`` with the module-level BLIP model.

    Args:
        image: Image handed to ``blip_processor``.
        prefix: Text passed to the processor together with the image.

    Returns:
        The decoded first generated sequence, with special tokens removed.
    """
    encoded = blip_processor(image, prefix, return_tensors="pt")
    # Move the inputs to the same device/dtype the model was loaded with.
    encoded = encoded.to("cuda", torch.float16)
    generated = blip_model.generate(**encoded)
    first_sequence = generated[0]
    return blip_processor.decode(first_sequence, skip_special_tokens=True)
def genderfromcaption(caption):
    """Map a caption to "Man", "Woman" or "Unsure" based on keyword presence.

    Words are compared case-insensitively and with surrounding punctuation
    removed, so "Woman." or "man," are recognised. Only whole words count
    ("human" does not match "man"). When both groups of keywords appear,
    "Man" wins because it is checked first.

    Args:
        caption: Caption text to inspect.

    Returns:
        "Man" if the caption contains "man" or "boy", "Woman" if it contains
        "woman" or "girl", otherwise "Unsure".
    """
    strip_chars = ".,;:!?\"'()[]"
    words = {word.strip(strip_chars).lower() for word in caption.split()}
    if words & {"man", "boy"}:
        return "Man"
    if words & {"woman", "girl"}:
        return "Woman"
    return "Unsure"
def genderplot(genlist):
    """Draw a 2x5 grid of colored tiles, one per gender label.

    Labels are ordered Man, Woman, Unsure before being drawn. Cells beyond the
    number of labels are left empty; labels beyond the tenth are not drawn.

    Args:
        genlist: Iterable of labels, each one of "Man", "Woman" or "Unsure".

    Returns:
        The matplotlib figure containing the grid.

    Raises:
        ValueError: If a label is not one of the three known values.
    """
    order = ["Man", "Woman", "Unsure"]
    colors = {"Man": "lightgreen", "Woman": "darkgreen", "Unsure": "lightgrey"}
    words = sorted(genlist, key=order.index)

    fig, axes = plt.subplots(2, 5, figsize=(5, 5))
    plt.subplots_adjust(hspace=0.1, wspace=0.1)
    for i, ax in enumerate(axes.flat):
        ax.set_axis_off()
        # Guard against fewer labels than grid cells instead of raising
        # IndexError, mirroring skintoneplot.
        if i < len(words):
            ax.add_patch(plt.Rectangle((0, 0), 1, 1, color=colors[words[i]]))
    return fig
def skintoneplot(hex_codes):
    """Draw a 2x5 grid of skin-tone tiles, brightest first.

    ``None`` entries are dropped. The remaining colors are ranked by the
    weighted sum 0.299*R + 0.587*G + 0.114*B in descending order (ties broken
    by the hex string, also descending) and painted into the grid; unused
    cells stay empty.

    Args:
        hex_codes: Iterable of hex color strings, possibly containing ``None``.

    Returns:
        The matplotlib figure containing the grid.
    """
    def luminance(code):
        red, green, blue = hex2color(code)
        return 0.299 * red + 0.587 * green + 0.114 * blue

    present = [code for code in hex_codes if code is not None]
    ranked = sorted(present, key=lambda code: (luminance(code), code), reverse=True)

    fig, axes = plt.subplots(2, 5, figsize=(5, 5))
    plt.subplots_adjust(hspace=0.1, wspace=0.1)
    cells = list(axes.flat)
    for cell in cells:
        cell.set_axis_off()
    # zip stops at the shorter sequence, so surplus cells remain blank.
    for cell, code in zip(cells, ranked):
        cell.add_patch(plt.Rectangle((0, 0), 1, 1, color=code))
    return fig
@spaces.GPU
def generate_images_plots(prompt, model_name, token=None):
    """Generate ten images and summarise their skin tones and genders.

    Each image is captioned with BLIP to derive a gender label and written to
    the ``temp`` folder so that ``stone.process`` can read it from disk.

    Args:
        prompt: Text prompt used for every image.
        model_name: Hub repository id of one of the supported checkpoints.
        token: Optional Hugging Face token, either the string itself or an
            object carrying it in a ``token`` attribute.

    Returns:
        A tuple ``(images, skin_tone_figure, gender_figure)``.
    """
    global pipeline_text2image
    # Accept both a bare token string and a token object.
    hf_token = getattr(token, "token", token)
    pipeline_text2image = load_model(model_name, hf_token)

    foldername = "temp"
    Path(foldername).mkdir(parents=True, exist_ok=True)

    images = [getimgen(prompt, model_name) for _ in range(10)]
    prompt_prefix = "photo of a "
    genders = []
    skintones = []
    for i, image in enumerate(images):
        caption = blip_caption_image(image, prefix=prompt_prefix)
        genders.append(genderfromcaption(caption))

        image_path = f"{foldername}/image_{i}.png"
        image.save(image_path)
        try:
            skintoneres = stone.process(image_path, return_report_image=False)
            tone = skintoneres['faces'][0]['dominant_colors'][0]['color']
        except Exception:
            # Best effort: an image without a usable face result contributes
            # no tone. Unlike a bare `except`, this lets KeyboardInterrupt and
            # SystemExit propagate.
            tone = None
        skintones.append(tone)
    return images, skintoneplot(skintones), genderplot(genders)
def _generate_with_login(prompt_text, model_name, oauth_token: Optional[gr.OAuthToken]):
    """Click handler: forward the logged-in user's token to the generator.

    Gradio fills ``oauth_token`` from the parameter's type hint (it must not
    be listed in ``inputs``); it is ``None`` when the user is not logged in.
    """
    hf_token = oauth_token.token if oauth_token is not None else None
    return generate_images_plots(prompt_text, model_name, hf_token)


with gr.Blocks(title="Skin Tone and Gender bias in Text to Image Models") as demo:
    gr.Markdown("# Skin Tone and Gender bias in Text to Image Models")
    model_dropdown = gr.Dropdown(
        label="Choose a model",
        choices=[
            "stabilityai/sdxl-turbo",
            "runwayml/stable-diffusion-v1-5",
            "ByteDance/SDXL-Lightning",
            "segmind/SSD-1B",
            "stabilityai/stable-diffusion-3-medium-diffusers",
        ],
        value=default_model,
    )
    prompt = gr.Textbox(label="Enter the Prompt")
    gallery = gr.Gallery(
        label="Generated images",
        show_label=False,
        elem_id="gallery",
        columns=[5],
        rows=[2],
        object_fit="contain",
        height="auto",
    )
    # gr.OAuthToken is a data class, not a component: it cannot be created
    # with no arguments or used as an input. The login button provides the
    # session, and the handler's type hint receives the token.
    gr.LoginButton()
    gr.Markdown('### You need to log in to your Hugging Face account to run Stable Diffusion 3')
    btn = gr.Button("Generate images", scale=0)
    with gr.Row(equal_height=True):
        skinplot = gr.Plot(label="Skin Tone")
        genplot = gr.Plot(label="Gender")
    btn.click(
        _generate_with_login,
        inputs=[prompt, model_dropdown],
        outputs=[gallery, skinplot, genplot],
    )

demo.launch(debug=True)