text-to-image-bias / app copy.py
Avijit Ghosh
playing around with model options
f56644b
raw
history blame
4.47 kB
import gradio as gr
import torch
from diffusers import AutoPipelineForText2Image
from transformers import BlipProcessor, BlipForConditionalGeneration
from pathlib import Path
import stone
import requests
import io
import os
from PIL import Image
import spaces
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import hex2color
# Text-to-image model: SDXL-Turbo, loaded once at import time with the fp16
# weight variant in float16 and moved to the GPU.
pipeline_text2image = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/sdxl-turbo",
    variant="fp16",
    torch_dtype=torch.float16,
).to("cuda")
@spaces.GPU
def getimgen(prompt):
    """Generate a single image for *prompt* with the SDXL-Turbo pipeline.

    Runs 2 inference steps with guidance_scale=0.0 and returns the first
    image of the pipeline output.
    """
    output = pipeline_text2image(
        prompt=prompt,
        num_inference_steps=2,
        guidance_scale=0.0,
    )
    first_image = output.images[0]
    return first_image
# Image-captioning model (BLIP large): the processor prepares inputs and
# decodes outputs; the model weights are loaded in float16 and moved to the GPU.
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
blip_model = BlipForConditionalGeneration.from_pretrained(
    "Salesforce/blip-image-captioning-large", torch_dtype=torch.float16
)
blip_model = blip_model.to("cuda")
@spaces.GPU
def blip_caption_image(image, prefix):
    """Caption *image* with BLIP, conditioning generation on the text *prefix*.

    Returns the decoded first generated sequence with special tokens removed.
    """
    model_inputs = blip_processor(image, prefix, return_tensors="pt")
    # Put the inputs on the GPU in float16 to match how blip_model was loaded.
    model_inputs = model_inputs.to("cuda", torch.float16)
    generated = blip_model.generate(**model_inputs)
    best_sequence = generated[0]
    return blip_processor.decode(best_sequence, skip_special_tokens=True)
def genderfromcaption(caption):
    """Classify a caption as "Man", "Woman" or "Unsure" by keyword lookup.

    The caption is lower-cased and split on whitespace, and surrounding
    punctuation is stripped from every word. This lets "A man." or
    "woman," still match. Whole words are compared, so "woman" or
    "mannequin" never count as "man".

    Male keywords are checked first. A caption mentioning both a man and
    a woman is therefore classified as "Man", the same precedence as before.

    Args:
        caption: free-text caption, e.g. from blip_caption_image.

    Returns:
        "Man", "Woman" or "Unsure".
    """
    import string

    words = {word.strip(string.punctuation) for word in caption.lower().split()}
    if words & {"man", "boy"}:
        return "Man"
    if words & {"woman", "girl"}:
        return "Woman"
    return "Unsure"
def genderplot(genlist):
    """Draw gender labels as coloured tiles in a 2x5 grid.

    Labels are grouped in the order Man, Woman, Unsure. Any other label
    sorts last and is drawn light grey instead of raising ValueError.
    At most 10 labels are drawn. If fewer are given, the remaining cells
    are left blank instead of raising IndexError.

    Args:
        genlist: labels as produced by genderfromcaption.

    Returns:
        The matplotlib Figure containing the grid.
    """
    order = ["Man", "Woman", "Unsure"]
    colors = {"Man": "lightgreen", "Woman": "darkgreen", "Unsure": "lightgrey"}
    # Rank lookup instead of order.index so unknown labels don't raise.
    rank = {label: position for position, label in enumerate(order)}
    words = sorted(genlist, key=lambda word: rank.get(word, len(order)))

    fig, axes = plt.subplots(2, 5, figsize=(5, 5))
    # Adjust this figure explicitly rather than pyplot's global "current" figure.
    fig.subplots_adjust(hspace=0.1, wspace=0.1)
    for i, ax in enumerate(axes.flat):
        ax.set_axis_off()
        if i < len(words):
            color = colors.get(words[i], "lightgrey")
            ax.add_patch(plt.Rectangle((0, 0), 1, 1, color=color))
    return fig
def skintoneplot(hex_codes):
    """Draw detected skin-tone colours as tiles in a 2x5 grid.

    None entries are skipped. generate_images_plots records None when no
    skin tone was found for an image, and hex2color(None) would otherwise
    crash the request. The remaining colours are sorted by luminance from
    light to dark. At most 10 tiles are drawn; unused cells are left blank.

    Args:
        hex_codes: colour strings accepted by matplotlib's hex2color, or None.

    Returns:
        The matplotlib Figure containing the grid.
    """
    valid_codes = [code for code in hex_codes if code is not None]
    # Convert hex codes to RGB values
    rgb_values = [hex2color(code) for code in valid_codes]
    # Luminance from the ITU-R BT.601 luma weights
    luminance_values = [0.299 * r + 0.587 * g + 0.114 * b for r, g, b in rgb_values]
    # Highest luminance first, i.e. light to dark
    sorted_hex_codes = [
        code for _, code in sorted(zip(luminance_values, valid_codes), reverse=True)
    ]

    fig, axes = plt.subplots(2, 5, figsize=(5, 5))
    # Adjust this figure explicitly rather than pyplot's global "current" figure.
    fig.subplots_adjust(hspace=0.1, wspace=0.1)
    for i, ax in enumerate(axes.flat):
        ax.set_axis_off()
        if i < len(sorted_hex_codes):
            ax.add_patch(plt.Rectangle((0, 0), 1, 1, color=sorted_hex_codes[i]))
    return fig
@spaces.GPU
def generate_images_plots(prompt):
    """Generate 10 images for *prompt* and summarise skin tone and gender.

    For each image:
    - BLIP captions it with the prefix "photo of a ", and the caption is
      mapped to a gender label.
    - The image is saved as a PNG in a per-call temporary directory, so
      stone can read the dominant colour of its first face.

    Failed skin-tone lookups are recorded as None, for example when stone
    returns no faces.

    Args:
        prompt: text prompt for the image generator.

    Returns:
        A tuple (images, skin-tone figure, gender figure).
    """
    import tempfile

    num_images = 10
    prompt_prefix = "photo of a "
    images = [getimgen(prompt) for _ in range(num_images)]
    genders = []
    skintones = []
    # A fresh directory per call avoids concurrent requests overwriting each
    # other's files, and it is deleted once the skin tones are extracted.
    with tempfile.TemporaryDirectory() as foldername:
        for i, image in enumerate(images):
            caption = blip_caption_image(image, prefix=prompt_prefix)
            genders.append(genderfromcaption(caption))
            image_path = str(Path(foldername) / f"image_{i}.png")
            image.save(image_path)
            try:
                skintoneres = stone.process(image_path, return_report_image=False)
                tone = skintoneres['faces'][0]['dominant_colors'][0]['color']
            except Exception:
                # Best effort: no face / unexpected result -> unknown tone.
                # (Exception rather than a bare except so Ctrl-C still works.)
                tone = None
            skintones.append(tone)
    print(genders, skintones)
    return images, skintoneplot(skintones), genderplot(genders)
# Gradio UI: prompt box, button, a 2x5 image gallery, and the two summary
# plots side by side. Clicking the button runs generate_images_plots.
with gr.Blocks(title="Skin Tone and Gender bias in SDXL Demo - Inference API") as demo:
    gr.Markdown("# Skin Tone and Gender bias in SDXL Demo")
    prompt = gr.Textbox(label="Enter the Prompt")
    gallery = gr.Gallery(
        label="Generated images",
        show_label=False,
        elem_id="gallery",
        columns=[5],
        rows=[2],
        object_fit="contain",
        height="auto",
    )
    btn = gr.Button("Generate images", scale=0)
    with gr.Row(equal_height=True):
        skinplot = gr.Plot(label="Skin Tone")
        genplot = gr.Plot(label="Gender")
    btn.click(
        fn=generate_images_plots,
        inputs=prompt,
        outputs=[gallery, skinplot, genplot],
    )

demo.launch(debug=True)