Spaces: Running on A100
import gradio as gr | |
import torch | |
from diffusers import AutoPipelineForText2Image | |
from transformers import BlipProcessor, BlipForConditionalGeneration | |
from pathlib import Path | |
import stone | |
import requests | |
import io | |
import os | |
from PIL import Image | |
import spaces | |
import matplotlib.pyplot as plt | |
import numpy as np | |
from matplotlib.colors import hex2color | |
# SDXL-Turbo text-to-image pipeline: loaded once at import time with the fp16
# weight variant in float16, then moved to the CUDA device.
pipeline_text2image = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/sdxl-turbo",
    variant="fp16",
    torch_dtype=torch.float16,
).to("cuda")
def getimgen(prompt):
    """Generate a single image for *prompt* with the module-level SDXL-Turbo pipeline.

    Classifier-free guidance is disabled (``guidance_scale=0.0``) and only two
    denoising steps are run.

    Args:
        prompt: text prompt to render.

    Returns:
        The first entry of the pipeline output's ``.images`` list (callers
        ``.save()`` it, so presumably a PIL image — confirm against the
        pipeline's ``output_type``).
    """
    result = pipeline_text2image(
        prompt=prompt,
        num_inference_steps=2,
        guidance_scale=0.0,
    )
    return result.images[0]
# BLIP (large) image-captioning checkpoint: the processor handles image/text
# preprocessing and decoding; the model weights are loaded in float16 and
# moved to the CUDA device.
BLIP_CHECKPOINT = "Salesforce/blip-image-captioning-large"
blip_processor = BlipProcessor.from_pretrained(BLIP_CHECKPOINT)
blip_model = BlipForConditionalGeneration.from_pretrained(
    BLIP_CHECKPOINT, torch_dtype=torch.float16
)
blip_model = blip_model.to("cuda")
def blip_caption_image(image, prefix):
    """Caption *image* with BLIP, continuing from the text *prefix*.

    Args:
        image: image passed straight to ``blip_processor``.
        prefix: text the generated caption is conditioned on
            (e.g. "photo of a ").

    Returns:
        The decoded caption of the first generated sequence, with special
        tokens removed.
    """
    # Inputs go to the GPU; the float16 cast matches the fp16 model weights.
    encoded = blip_processor(image, prefix, return_tensors="pt")
    encoded = encoded.to("cuda", torch.float16)
    generated_ids = blip_model.generate(**encoded)
    first_sequence = generated_ids[0]
    return blip_processor.decode(first_sequence, skip_special_tokens=True)
def genderfromcaption(caption):
    """Guess a coarse gender label from a caption.

    The caption is lower-cased and split into runs of ASCII letters, so
    trailing punctuation ("a man.") or capitalisation ("A Woman") does not
    hide a keyword, while words that merely contain one ("snowman") do not
    match. "man"/"boy" is checked before "woman"/"girl", so a caption that
    mentions both resolves to "Man".

    Args:
        caption: caption text, e.g. as produced by ``blip_caption_image``.

    Returns:
        "Man", "Woman", or "Unsure" when neither keyword group appears.
    """
    import re

    words = set(re.findall(r"[a-z]+", caption.lower()))
    if words & {"man", "boy"}:
        return "Man"
    if words & {"woman", "girl"}:
        return "Woman"
    return "Unsure"
def genderplot(genlist):
    """Render gender labels as a 2x5 grid of coloured squares.

    Labels are grouped in the order Man, Woman, Unsure so the grid reads like
    a stacked bar. Colours: Man -> lightgreen, Woman -> darkgreen,
    Unsure -> lightgrey.

    Args:
        genlist: iterable of labels, each "Man", "Woman" or "Unsure" (the
            values ``genderfromcaption`` returns). Any other label raises
            ValueError from ``list.index``.

    Returns:
        A ``matplotlib.figure.Figure``. The grid holds 10 cells: with fewer
        labels the remaining cells are left blank, and labels beyond the 10th
        are not drawn.
    """
    # Build the figure without pyplot: pyplot keeps every figure it creates
    # registered until explicitly closed, so a long-running server would leak
    # one figure per request.
    from matplotlib.figure import Figure

    order = ["Man", "Woman", "Unsure"]
    colors = {"Man": "lightgreen", "Woman": "darkgreen", "Unsure": "lightgrey"}
    words = sorted(genlist, key=order.index)
    word_colors = [colors[word] for word in words]

    fig = Figure(figsize=(5, 5))
    axes = fig.subplots(2, 5)
    # Tight spacing so the cells read as one block.
    fig.subplots_adjust(hspace=0.1, wspace=0.1)
    for i, ax in enumerate(axes.flat):
        ax.set_axis_off()
        # Guard instead of indexing blindly: fewer than 10 labels must not
        # raise IndexError.
        if i < len(word_colors):
            ax.add_patch(plt.Rectangle((0, 0), 1, 1, color=word_colors[i]))
    return fig
def skintoneplot(hex_codes):
    """Render detected skin tones as a 2x5 grid of coloured squares.

    Tones are ordered by perceived luminance (0.299 R + 0.587 G + 0.114 B),
    lightest first. ``None`` entries, which ``generate_images_plots`` records
    when skin-tone detection fails for an image, are skipped.

    Args:
        hex_codes: iterable of colour strings accepted by
            ``matplotlib.colors.hex2color`` (e.g. "#aa7755"), or None for
            images without a detected tone.

    Returns:
        A ``matplotlib.figure.Figure``. At most 10 tones are drawn; cells
        without a tone (failed detections or fewer than 10 inputs) are blank.
    """
    # Build the figure without pyplot: pyplot keeps every figure it creates
    # registered until explicitly closed, so a long-running server would leak
    # one figure per request.
    from matplotlib.figure import Figure

    def luminance(code):
        red, green, blue = hex2color(code)
        return 0.299 * red + 0.587 * green + 0.114 * blue

    # hex2color(None) raises ValueError, so drop missing detections instead of
    # letting one face-less image fail the whole request.
    valid_codes = [code for code in hex_codes if code is not None]
    # Descending luminance = light to dark; equal luminances fall back to
    # comparing the code strings (also descending).
    sorted_hex_codes = sorted(
        valid_codes, key=lambda code: (luminance(code), code), reverse=True
    )

    fig = Figure(figsize=(5, 5))
    axes = fig.subplots(2, 5)
    # Tight spacing so the cells read as one block.
    fig.subplots_adjust(hspace=0.1, wspace=0.1)
    for i, ax in enumerate(axes.flat):
        ax.set_axis_off()
        if i < len(sorted_hex_codes):
            ax.add_patch(plt.Rectangle((0, 0), 1, 1, color=sorted_hex_codes[i]))
    return fig
def generate_images_plots(prompt):
    """Generate 10 images for *prompt* and summarise their skin tones and genders.

    Each image is captioned with BLIP (prefix "photo of a ") to derive a
    gender label via ``genderfromcaption``. It is also saved to disk and
    passed to ``stone.process``, which estimates a skin tone from the first
    dominant colour of the first detected face.

    Args:
        prompt: text prompt entered in the UI.

    Returns:
        ``(images, skin_tone_figure, gender_figure)``, matching the gallery
        and the two plot outputs wired up in the UI.
    """
    import tempfile

    prompt_prefix = "photo of a "
    images = [getimgen(prompt) for _ in range(10)]
    genders = []
    skintones = []
    # Per-call temporary directory instead of a shared "temp" folder: each
    # call gets its own files (no overwriting if calls ever overlap), and they
    # are removed when the block exits.
    with tempfile.TemporaryDirectory() as tmpdir:
        for i, image in enumerate(images):
            caption = blip_caption_image(image, prefix=prompt_prefix)
            image_path = os.path.join(tmpdir, f"image_{i}.png")
            image.save(image_path)
            try:
                skintoneres = stone.process(image_path, return_report_image=False)
                tone = skintoneres["faces"][0]["dominant_colors"][0]["color"]
            except Exception as err:
                # Best effort: an image without a detectable face (or any
                # stone error) yields None rather than failing the request.
                # Unlike a bare except, this does not swallow
                # KeyboardInterrupt/SystemExit.
                print(f"Skin tone detection failed for image {i}: {err!r}")
                tone = None
            skintones.append(tone)
            genders.append(genderfromcaption(caption))
    print(genders, skintones)
    return images, skintoneplot(skintones), genderplot(genders)
# Gradio UI: the button sends the prompt to generate_images_plots, whose three
# return values fill the image gallery, the skin-tone plot and the gender plot.
with gr.Blocks(title="Skin Tone and Gender bias in SDXL Demo - Inference API") as demo:
    gr.Markdown("# Skin Tone and Gender bias in SDXL Demo")
    prompt_box = gr.Textbox(label="Enter the Prompt")
    gallery = gr.Gallery(
        label="Generated images",
        show_label=False,
        elem_id="gallery",
        columns=[5],
        rows=[2],
        object_fit="contain",
        height="auto",
    )
    generate_btn = gr.Button("Generate images", scale=0)
    with gr.Row(equal_height=True):
        skin_plot = gr.Plot(label="Skin Tone")
        gender_plot = gr.Plot(label="Gender")
    generate_btn.click(
        generate_images_plots,
        inputs=prompt_box,
        outputs=[gallery, skin_plot, gender_plot],
    )

demo.launch(debug=True)