# lucianosb's picture
# Adds Age Classifier and NSFW Classifier
# 8765391 verified
# raw
# history blame
# 13.1 kB
import functools
import os
from pathlib import Path

import gradio as gr
import matplotlib.pyplot as plt
import spaces
import stone
import torch
from diffusers import (
    DiffusionPipeline,
    EulerDiscreteScheduler,
    StableDiffusion3Pipeline,
    StableDiffusionPipeline,
    StableDiffusionXLPipeline,
    UNet2DConditionModel,
)
from huggingface_hub import hf_hub_download
from matplotlib.colors import hex2color
from PIL import Image
from safetensors.torch import load_file
from transformers import BlipForConditionalGeneration, BlipProcessor, pipeline
# Authenticate against the Hugging Face Hub with the token stored in the
# "AccessTokenSD3" environment variable (presumably required for the gated
# Stable Diffusion 3 checkpoint -- confirm against the deployment secrets).
access_token = os.getenv("AccessTokenSD3")
from huggingface_hub import login

if access_token:
    login(token=access_token)
else:
    # login(token=None) falls back to an interactive prompt, which cannot be
    # answered in a headless deployment; skip it and let any gated download
    # fail later with its own authorization error instead.
    print("AccessTokenSD3 is not set; skipping Hugging Face Hub login.")
# Define model initialization functions
def load_model(model_name):
    """Build the text-to-image pipeline for ``model_name`` on the GPU.

    Every supported checkpoint is loaded in float16 and moved to the "cuda"
    device before being returned.

    Args:
        model_name: Hugging Face repo id of one of the supported checkpoints.

    Returns:
        The loaded diffusers pipeline.

    Raises:
        ValueError: If ``model_name`` is not one of the supported checkpoints.
    """
    if model_name == "stabilityai/sdxl-turbo":
        pipe = DiffusionPipeline.from_pretrained(
            model_name, torch_dtype=torch.float16, variant="fp16"
        )
        return pipe.to("cuda")

    if model_name == "runwayml/stable-diffusion-v1-5":
        pipe = StableDiffusionPipeline.from_pretrained(
            model_name, torch_dtype=torch.float16
        )
        return pipe.to("cuda")

    if model_name == "ByteDance/SDXL-Lightning":
        # The Lightning repo only ships a distilled UNet; it is dropped into
        # the SDXL base pipeline in place of the stock UNet.
        base = "stabilityai/stable-diffusion-xl-base-1.0"
        ckpt = "sdxl_lightning_4step_unet.safetensors"
        unet = UNet2DConditionModel.from_config(base, subfolder="unet")
        unet = unet.to("cuda", torch.float16)
        weights_path = hf_hub_download(model_name, ckpt)
        unet.load_state_dict(load_file(weights_path, device="cuda"))
        pipe = StableDiffusionXLPipeline.from_pretrained(
            base, unet=unet, torch_dtype=torch.float16, variant="fp16"
        ).to("cuda")
        # Replace the default scheduler with Euler using "trailing" spacing.
        pipe.scheduler = EulerDiscreteScheduler.from_config(
            pipe.scheduler.config, timestep_spacing="trailing"
        )
        return pipe

    if model_name == "segmind/SSD-1B":
        pipe = StableDiffusionXLPipeline.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            use_safetensors=True,
            variant="fp16",
        )
        return pipe.to("cuda")

    if model_name == "stabilityai/stable-diffusion-3-medium-diffusers":
        pipe = StableDiffusion3Pipeline.from_pretrained(
            model_name, torch_dtype=torch.float16
        )
        return pipe.to("cuda")

    if model_name == "stabilityai/stable-diffusion-2":
        euler = EulerDiscreteScheduler.from_pretrained(model_name, subfolder="scheduler")
        pipe = StableDiffusionPipeline.from_pretrained(
            model_name, scheduler=euler, torch_dtype=torch.float16
        )
        return pipe.to("cuda")

    raise ValueError("Unknown model name")
# Initialize the default model
# The pipeline is loaded once at import time and kept in a module-level
# global: getimgen() reads it and generate_images_plots() rebinds it.
# default_model is also the initial value of the model dropdown in the UI.
default_model = "stabilityai/stable-diffusion-3-medium-diffusers"
pipeline_text2image = load_model(default_model)
# Per-model sampling arguments passed to the pipeline alongside the prompt.
# An empty dict means the pipeline's own defaults are used.
_GENERATION_KWARGS = {
    "stabilityai/sdxl-turbo": {"guidance_scale": 0.0, "num_inference_steps": 2},
    "runwayml/stable-diffusion-v1-5": {},
    "ByteDance/SDXL-Lightning": {"num_inference_steps": 4, "guidance_scale": 0},
    "segmind/SSD-1B": {"negative_prompt": "ugly, blurry, poor quality"},
    "stabilityai/stable-diffusion-3-medium-diffusers": {
        "negative_prompt": "",
        "num_inference_steps": 28,
        "guidance_scale": 7.0,
    },
    "stabilityai/stable-diffusion-2": {},
}


@spaces.GPU
def getimgen(prompt, model_name):
    """Generate one image for ``prompt`` with the currently loaded pipeline.

    ``model_name`` only selects the sampling arguments; the image itself is
    produced by the module-level ``pipeline_text2image``, so the caller must
    already have loaded the pipeline that matches ``model_name``.

    Args:
        prompt: Text prompt to render.
        model_name: Repo id of the model whose sampling settings to use.

    Returns:
        The first image of the pipeline output.

    Raises:
        ValueError: If ``model_name`` has no sampling settings. Previously
            the function fell through and silently returned ``None``.
    """
    try:
        sampling_kwargs = _GENERATION_KWARGS[model_name]
    except KeyError:
        raise ValueError("Unknown model name") from None
    return pipeline_text2image(prompt=prompt, **sampling_kwargs).images[0]
# BLIP captioning checkpoint; the processor and the model are loaded from the
# same repo id, and the model runs in float16 on the GPU.
_BLIP_CHECKPOINT = "Salesforce/blip-image-captioning-large"
blip_processor = BlipProcessor.from_pretrained(_BLIP_CHECKPOINT)
blip_model = BlipForConditionalGeneration.from_pretrained(
    _BLIP_CHECKPOINT,
    torch_dtype=torch.float16,
)
blip_model = blip_model.to("cuda")
@spaces.GPU
def blip_caption_image(image, prefix):
    """Caption ``image`` with BLIP, conditioning generation on ``prefix``.

    Args:
        image: Image to describe.
        prefix: Text passed to the processor together with the image.

    Returns:
        The decoded caption string with special tokens removed.
    """
    encoded = blip_processor(image, prefix, return_tensors="pt")
    # Move the inputs to the same device and dtype the model was loaded with.
    encoded = encoded.to("cuda", torch.float16)
    generated = blip_model.generate(**encoded)
    first_sequence = generated[0]
    return blip_processor.decode(first_sequence, skip_special_tokens=True)
def genderfromcaption(caption):
    """Map a caption to "Man", "Woman" or "Unsure" based on gendered words.

    The caption is split on whitespace, and each token is lower-cased and
    stripped of surrounding punctuation before matching, so "Woman." and
    "man," are recognised. Only whole tokens match ("human" does not count
    as "man").

    Args:
        caption: Caption text, e.g. the output of the BLIP captioner.

    Returns:
        "Man" if the caption contains "man" or "boy"; otherwise "Woman" if it
        contains "woman" or "girl"; otherwise "Unsure". When both kinds of
        word are present, "Man" wins because it is checked first.
    """
    tokens = {word.strip(".,;:!?\"'()[]") for word in caption.lower().split()}
    if tokens & {"man", "boy"}:
        return "Man"
    if tokens & {"woman", "girl"}:
        return "Woman"
    return "Unsure"
def genderplot(genlist):
    """Draw a 2x5 grid of coloured squares, one square per gender label.

    Labels are ordered Man, Woman, Unsure and painted light green, dark green
    and light grey respectively. Cells beyond the number of labels are left
    blank instead of raising IndexError; labels beyond the tenth are ignored.

    Args:
        genlist: List of labels, each "Man", "Woman" or "Unsure".

    Returns:
        The matplotlib figure holding the grid.

    Raises:
        ValueError: If a label is not one of the three known values.
    """
    order = ["Man", "Woman", "Unsure"]
    colors = {"Man": "lightgreen", "Woman": "darkgreen", "Unsure": "lightgrey"}
    ordered = sorted(genlist, key=order.index)

    fig, axes = plt.subplots(2, 5, figsize=(5, 5))
    fig.subplots_adjust(hspace=0.1, wspace=0.1)
    cells = list(axes.flat)
    for ax in cells:
        ax.set_axis_off()
    # zip stops at the shorter sequence, so a short list cannot overrun the grid.
    for ax, label in zip(cells, ordered):
        ax.add_patch(plt.Rectangle((0, 0), 1, 1, color=colors[label]))
    return fig
def skintoneplot(hex_codes):
    """Draw a 2x5 grid of skin-tone swatches ordered from lightest to darkest.

    ``None`` entries are dropped. The remaining colours are ranked by the
    weighted sum 0.299*R + 0.587*G + 0.114*B, highest first, and each fills
    one cell; cells without a colour stay blank.

    Args:
        hex_codes: List of hex colour strings, possibly containing ``None``.

    Returns:
        The matplotlib figure holding the grid.
    """
    def luminance(code):
        red, green, blue = hex2color(code)
        return 0.299 * red + 0.587 * green + 0.114 * blue

    present = [code for code in hex_codes if code is not None]
    # Rank on (luminance, code) so ties are broken by the code itself.
    ranked = sorted(present, key=lambda code: (luminance(code), code), reverse=True)

    fig, axes = plt.subplots(2, 5, figsize=(5, 5))
    fig.subplots_adjust(hspace=0.1, wspace=0.1)
    for ax in axes.flat:
        ax.set_axis_off()
    for ax, code in zip(axes.flat, ranked):
        ax.add_patch(plt.Rectangle((0, 0), 1, 1, color=code))
    return fig
@functools.lru_cache(maxsize=1)
def _age_classifier():
    """Build the age-classification pipeline once and reuse it across calls."""
    return pipeline('image-classification', model="dima806/faces_age_detection", device=0)


def age_detector(image):
    """
    Detect the age category of the subject in an image.

    The classifier is built on first use and cached, so the model is no
    longer reloaded for every image.

    Args:
        image: The input image for age detection.
    Returns:
        str: The label with the highest score among the classifier's
            predictions.
    """
    predictions = _age_classifier()(image)
    best = max(predictions, key=lambda item: item['score'])
    return best['label']
def ageplot(agelist):
    """
    Draw a 2x5 grid of coloured squares, one square per age category.

    Categories are ordered YOUNG, MIDDLE, OLD and painted sky blue, royal blue
    and dark blue respectively. Cells beyond the number of entries are left
    blank instead of raising IndexError; entries beyond the tenth are ignored.

    Args:
        agelist (list): A list of age categories ("YOUNG", "MIDDLE", "OLD").
    Returns:
        fig: A matplotlib figure object representing the age plot.
    Raises:
        ValueError: If an entry is not one of the three known categories.
    """
    order = ["YOUNG", "MIDDLE", "OLD"]
    colors = {"YOUNG": "skyblue", "MIDDLE": "royalblue", "OLD": "darkblue"}
    ordered = sorted(agelist, key=order.index)

    fig, axes = plt.subplots(2, 5, figsize=(5, 5))
    fig.subplots_adjust(hspace=0.1, wspace=0.1)
    cells = list(axes.flat)
    for ax in cells:
        ax.set_axis_off()
    # zip stops at the shorter sequence, so a short list cannot overrun the grid.
    for ax, label in zip(cells, ordered):
        ax.add_patch(plt.Rectangle((0, 0), 1, 1, color=colors[label]))
    return fig
@functools.lru_cache(maxsize=1)
def _nsfw_classifier():
    """Build the NSFW image-classification pipeline once and reuse it."""
    return pipeline("image-classification", model="Falconsai/nsfw_image_detection")


def is_nsfw(image):
    """
    Classify an image as safe or not safe for work (NSFW).

    The classifier is built on first use and cached, so the model is no
    longer reloaded for every image.

    Args:
        image: The input image to be classified.
    Returns:
        str: The label with the highest score among the classifier's
            predictions (presumably "normal" or "nsfw", the labels that
            nsfwplot expects -- confirm against the model card).
    """
    predictions = _nsfw_classifier()(image)
    best = max(predictions, key=lambda item: item['score'])
    return best['label']
def nsfwplot(nsfwlist):
    """
    Draw a 2x5 grid of coloured squares, one square per NSFW label.

    Labels are ordered "normal" then "nsfw" and painted misty rose and red
    respectively. Cells beyond the number of labels are left blank instead of
    raising IndexError; labels beyond the tenth are ignored.

    Args:
        nsfwlist (list): A list of NSFW labels ("normal" or "nsfw").
    Returns:
        fig: A matplotlib figure object representing the NSFW plot.
    Raises:
        ValueError: If a label is neither "normal" nor "nsfw".
    """
    order = ["normal", "nsfw"]
    colors = {"normal": "mistyrose", "nsfw": "red"}
    ordered = sorted(nsfwlist, key=order.index)

    fig, axes = plt.subplots(2, 5, figsize=(5, 5))
    fig.subplots_adjust(hspace=0.1, wspace=0.1)
    cells = list(axes.flat)
    for ax in cells:
        ax.set_axis_off()
    # zip stops at the shorter sequence, so a short list cannot overrun the grid.
    for ax, label in zip(cells, ordered):
        ax.add_patch(plt.Rectangle((0, 0), 1, 1, color=colors[label]))
    return fig
@spaces.GPU(duration=200)
def generate_images_plots(prompt, model_name):
    """Generate ten images for ``prompt`` and summarise their attributes.

    The selected model is loaded into the module-level ``pipeline_text2image``
    and ten images are generated. Each image is then captioned with BLIP for
    a gender label, saved under "temp/" and passed to ``stone`` for a skin
    tone, and run through the age and NSFW classifiers.

    Args:
        prompt: Text prompt used for every image.
        model_name: Repo id of the text-to-image model to load and use.

    Returns:
        A tuple ``(images, skin tone figure, gender figure, age figure,
        NSFW figure)`` matching the outputs wired to the Gradio button.
    """
    global pipeline_text2image
    pipeline_text2image = load_model(model_name)

    foldername = "temp"
    Path(foldername).mkdir(parents=True, exist_ok=True)

    images = [getimgen(prompt, model_name) for _ in range(10)]

    genders = []
    skintones = []
    ages = []
    nsfws = []
    prompt_prefix = "photo of a "
    for i, image in enumerate(images):
        caption = blip_caption_image(image, prefix=prompt_prefix)

        # stone.process is given a file path, so the image is written to disk first.
        image_path = f"{foldername}/image_{i}.png"
        image.save(image_path)
        try:
            skintoneres = stone.process(image_path, return_report_image=False)
            tone = skintoneres['faces'][0]['dominant_colors'][0]['color']
        except Exception:
            # Best effort: if no skin tone can be extracted, record None,
            # which skintoneplot drops.
            tone = None
        skintones.append(tone)

        genders.append(genderfromcaption(caption))
        # These two calls were missing, leaving `age` and `nsfw` undefined.
        ages.append(age_detector(image))
        nsfws.append(is_nsfw(image))

    return images, skintoneplot(skintones), genderplot(genders), ageplot(ages), nsfwplot(nsfws)
# Gradio UI: a model dropdown and a prompt box feed generate_images_plots,
# whose five return values fill the gallery and the four plots defined below.
with gr.Blocks(title="Skin Tone and Gender bias in Text-to-Image Generation Models") as demo:
    gr.Markdown("# Skin Tone and Gender bias in Text to Image Models")
    # Explanatory text rendered above the controls.
    # NOTE: the lines of this literal start at column 0; indenting them would
    # change the string that is passed to gr.Markdown.
    gr.Markdown('''
In this demo, we explore the potential biases in text-to-image models by generating multiple images based on user prompts and analyzing the gender and skin tone of the generated subjects. Here's how the analysis works:
1. **Image Generation**: For each prompt, 10 images are generated using the selected model.
2. **Gender Detection**: The [BLIP caption generator](https://huggingface.co/Salesforce/blip-image-captioning-large) is used to elicit gender markers by identifying words like "man," "boy," "woman," and "girl" in the captions.
3. **Skin Tone Classification**: The [skin-tone-classifier library](https://github.com/ChenglongMa/SkinToneClassifier) is used to extract the skin tones of the generated subjects.
4. **Age Detection**: The [Faces Age Detection model](https://huggingface.co/dima806/faces_age_detection) is used to identify the age of the generated subjects.
5. **NSFW Detection**: The [Falconsai/nsfw_image_detection](https://huggingface.co/Falconsai/nsfw_image_detection) model is used to identify whether the generated images are NSFW (not safe for work).
#### Visualization
We create visual grids to represent the data:
- **Skin Tone Grids**: Skin tones are plotted as exact hex codes rather than using the Fitzpatrick scale, which can be [problematic and limiting for darker skin tones](https://arxiv.org/pdf/2309.05148).
- **Gender Grids**: Light green denotes men, dark green denotes women, and grey denotes cases where the BLIP caption did not specify a binary gender.
- **Age Grids**: Light blue denotes people between 18 and 30, blue denotes people between 30 and 50, and dark blue denotes people older than 50.
- **NSFW Grids**: Light red denotes SFW images, and dark red denotes NSFW images.
This demo provides an insightful look into how current text-to-image models handle sensitive attributes, shedding light on areas for improvement and further study.
[Here is an article](https://medium.com/@evijit/analysis-of-ai-generated-images-of-indian-people-for-colorism-and-sexism-b80ff946759f) showing how this space can be used to perform such analyses, using colorism and sexism in India as an example.
''')
    # The choices mirror the model names handled by load_model and getimgen.
    model_dropdown = gr.Dropdown(
        label="Choose a model",
        choices=[
            "stabilityai/stable-diffusion-3-medium-diffusers",
            "stabilityai/sdxl-turbo",
            "ByteDance/SDXL-Lightning",
            "stabilityai/stable-diffusion-2",
            "runwayml/stable-diffusion-v1-5",
            "segmind/SSD-1B"
        ],
        value=default_model
    )
    prompt = gr.Textbox(label="Enter the Prompt", value = "photo of a doctor in india, detailed, 8k, sharp, high quality, good lighting")
    # 5 columns x 2 rows matches the ten images generated per request.
    gallery = gr.Gallery(
        label="Generated images",
        show_label=False,
        elem_id="gallery",
        columns=[5],
        rows=[2],
        object_fit="contain",
        height="auto"
    )
    btn = gr.Button("Generate images", scale=0)
    # Two rows of two plots each: skin tone and gender, then age and NSFW.
    with gr.Row(equal_height=True):
        skinplot = gr.Plot(label="Skin Tone")
        genplot = gr.Plot(label="Gender")
    with gr.Row(equal_height=True):
        agesplot = gr.Plot(label="Age")
        nsfwsplot = gr.Plot(label="NSFW")
    # The outputs are listed in the same order as generate_images_plots' return tuple.
    btn.click(generate_images_plots, inputs=[prompt, model_dropdown], outputs=[gallery, skinplot, genplot, agesplot, nsfwsplot])
# Start the app; this runs at import time because it sits at module top level.
demo.launch(debug=True)