File size: 7,229 Bytes
956fa05
 
680331e
956fa05
 
31a0f6f
e7204ee
956fa05
 
 
31a0f6f
 
e7204ee
 
f1aa060
 
 
31a0f6f
f1aa060
31a0f6f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
680331e
f1aa060
17497eb
680331e
 
17497eb
f1aa060
680331e
31a0f6f
 
 
 
 
 
 
956fa05
e7204ee
 
680331e
e7204ee
de81f33
 
 
 
 
 
 
 
 
680331e
 
956fa05
64fe77f
31a0f6f
956fa05
e7204ee
956fa05
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31a0f6f
956fa05
 
 
 
 
 
 
31a0f6f
 
956fa05
 
e7204ee
f1aa060
e7204ee
fadf2e1
956fa05
 
e7204ee
956fa05
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31a0f6f
f56644b
31a0f6f
 
 
 
 
 
f93dac8
 
31a0f6f
 
 
956fa05
31a0f6f
 
 
 
 
 
 
 
 
956fa05
 
 
 
f1aa060
956fa05
e7204ee
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
import os
import re
from functools import lru_cache
from pathlib import Path

import gradio as gr
import torch
from diffusers import DiffusionPipeline, StableDiffusionPipeline, StableDiffusionXLPipeline, EulerDiscreteScheduler, UNet2DConditionModel, StableDiffusion3Pipeline
from transformers import BlipProcessor, BlipForConditionalGeneration
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib.colors import hex2color
import stone
import spaces


# Hugging Face token needed by load_model for the SD3 checkpoint; None when
# the AccessTokenSD3 environment variable is not set.
api_key = os.environ.get("AccessTokenSD3")

# Define model initialization functions
@lru_cache(maxsize=1)
def load_model(model_name):
    """Build the fp16 diffusers pipeline for ``model_name`` and move it to CUDA.

    Results are memoized with ``maxsize=1``. Calling this again with the same
    name (``getimgen`` does so for every generated image) returns the
    already-loaded pipeline instead of rebuilding it. Requesting a different
    model replaces the cached entry.

    Args:
        model_name: Hugging Face repo id; must be one of the ids handled
            below (the same ids offered in the UI dropdown).

    Returns:
        The pipeline, already moved to ``"cuda"`` with ``torch.float16`` weights.

    Raises:
        ValueError: if ``model_name`` is not supported, or if the SD3 model is
            requested while ``api_key`` (env ``AccessTokenSD3``) is unset.
    """
    if model_name == "stabilityai/sdxl-turbo":
        pipeline = DiffusionPipeline.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            variant="fp16",
        ).to("cuda")
    elif model_name == "runwayml/stable-diffusion-v1-5":
        pipeline = StableDiffusionPipeline.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
        ).to("cuda")
    elif model_name == "ByteDance/SDXL-Lightning":
        # Load the 4-step Lightning UNet weights into the SDXL base pipeline.
        base = "stabilityai/stable-diffusion-xl-base-1.0"
        ckpt = "sdxl_lightning_4step_unet.safetensors"
        unet = UNet2DConditionModel.from_config(base, subfolder="unet").to("cuda", torch.float16)
        unet.load_state_dict(load_file(hf_hub_download(model_name, ckpt), device="cuda"))
        pipeline = StableDiffusionXLPipeline.from_pretrained(
            base,
            unet=unet,
            torch_dtype=torch.float16,
            variant="fp16",
        ).to("cuda")
        # Replace the default scheduler with Euler using "trailing" timestep spacing.
        pipeline.scheduler = EulerDiscreteScheduler.from_config(
            pipeline.scheduler.config, timestep_spacing="trailing"
        )
    elif model_name == "segmind/SSD-1B":
        pipeline = StableDiffusionXLPipeline.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            use_safetensors=True,
            variant="fp16",
        ).to("cuda")
    elif model_name == "stabilityai/stable-diffusion-3-medium-diffusers":
        if api_key is None:
            raise ValueError("Hugging Face token is required to access this model")
        pipeline = StableDiffusion3Pipeline.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            # `token` is the current diffusers keyword; the deprecated
            # `use_auth_token` may be ignored, leaving the gated download
            # unauthenticated.
            token=api_key,
        ).to("cuda")
    else:
        raise ValueError(f"Unknown model name: {model_name}")
    return pipeline

# Initialize the default model.
# Repo id pre-selected in the UI dropdown; its pipeline is built once at
# import time so `pipeline_text2image` is populated before any request.
default_model: str = "stabilityai/sdxl-turbo"
# Module-level pipeline handle; getimgen rebinds this global to the pipeline
# for whichever model a request asks for.
pipeline_text2image = load_model(default_model)

# Sampling settings per supported model. getimgen passes the prompt as
# `prompt=` plus the keyword arguments listed here.
_GENERATION_KWARGS = {
    "stabilityai/sdxl-turbo": {"guidance_scale": 0.0, "num_inference_steps": 2},
    "runwayml/stable-diffusion-v1-5": {},
    "ByteDance/SDXL-Lightning": {"num_inference_steps": 4, "guidance_scale": 0},
    "segmind/SSD-1B": {"negative_prompt": "ugly, blurry, poor quality"},
    "stabilityai/stable-diffusion-3-medium-diffusers": {
        "negative_prompt": "",
        "num_inference_steps": 28,
        "guidance_scale": 7.0,
    },
}


@spaces.GPU
def getimgen(prompt, model_name):
    """Generate a single image for ``prompt`` with the model ``model_name``.

    The pipeline is obtained from ``load_model`` and stored in the
    module-level ``pipeline_text2image`` before sampling with that model's
    settings from ``_GENERATION_KWARGS``.

    Returns:
        The first image produced by the pipeline, or ``None`` for a model
        without a settings entry. That case is not reached for unknown names,
        because ``load_model`` raises ``ValueError`` first.
    """
    global pipeline_text2image
    pipeline_text2image = load_model(model_name)
    settings = _GENERATION_KWARGS.get(model_name)
    if settings is None:
        return None
    output = pipeline_text2image(prompt=prompt, **settings)
    return output.images[0]

# BLIP captioning checkpoint shared by the processor and the model.
_BLIP_CHECKPOINT = "Salesforce/blip-image-captioning-large"
blip_processor = BlipProcessor.from_pretrained(_BLIP_CHECKPOINT)
# Model weights are loaded in float16 and placed on the GPU once at import time.
blip_model = BlipForConditionalGeneration.from_pretrained(
    _BLIP_CHECKPOINT,
    torch_dtype=torch.float16,
).to("cuda")

@spaces.GPU
def blip_caption_image(image, prefix):
    """Caption ``image`` with BLIP, using ``prefix`` as the text prompt.

    Inputs are moved to CUDA in float16 to match ``blip_model``. The first
    generated sequence is decoded with special tokens removed.
    """
    encoded = blip_processor(image, prefix, return_tensors="pt")
    encoded = encoded.to("cuda", torch.float16)
    token_ids = blip_model.generate(**encoded)
    first_sequence = token_ids[0]
    return blip_processor.decode(first_sequence, skip_special_tokens=True)

# Lower-case alphabetic runs; punctuation and digits act as separators.
_CAPTION_WORD_RE = re.compile(r"[a-z]+")
_MALE_WORDS = frozenset({"man", "men", "boy", "boys"})
_FEMALE_WORDS = frozenset({"woman", "women", "girl", "girls"})


def genderfromcaption(caption):
    """Classify a caption as "Man", "Woman" or "Unsure" by keyword matching.

    The caption is lower-cased and split into alphabetic words. As a result,
    capitalisation ("Man") and attached punctuation ("woman,", "girl's") do
    not hide a match, and plural forms are recognised. Substrings never
    match (e.g. "human" is not "man").

    Male keywords are checked first, so a caption mentioning both genders
    is reported as "Man".
    """
    words = set(_CAPTION_WORD_RE.findall(caption.lower()))
    if words & _MALE_WORDS:
        return "Man"
    if words & _FEMALE_WORDS:
        return "Woman"
    return "Unsure"

def genderplot(genlist):
    """Draw a 2x5 grid of coloured tiles, one per gender label.

    Labels are grouped in the order Man, Woman, Unsure (light green, dark
    green, light grey). If ``genlist`` has fewer than ten entries, the
    remaining tiles are left blank, as in ``skintoneplot``. Entries beyond
    the tenth are not drawn. A label outside the three known ones raises
    ``ValueError`` from ``list.index``.

    Returns:
        The matplotlib figure.
    """
    order = ["Man", "Woman", "Unsure"]
    words = sorted(genlist, key=lambda x: order.index(x))
    colors = {"Man": "lightgreen", "Woman": "darkgreen", "Unsure": "lightgrey"}
    word_colors = [colors[word] for word in words]
    fig, axes = plt.subplots(2, 5, figsize=(5, 5))
    # Adjust this figure directly rather than pyplot's global "current figure",
    # which other concurrently running handlers may also be using.
    fig.subplots_adjust(hspace=0.1, wspace=0.1)
    for i, ax in enumerate(axes.flat):
        ax.set_axis_off()
        if i < len(word_colors):
            ax.add_patch(plt.Rectangle((0, 0), 1, 1, color=word_colors[i]))
    return fig

def skintoneplot(hex_codes):
    """Draw a 2x5 grid of skin-tone swatches, lightest first.

    ``None`` entries (images where no tone was extracted) are dropped.
    Remaining codes are sorted by descending luma
    (0.299 R + 0.587 G + 0.114 B). Unused tiles stay blank, and codes
    beyond the tenth are not drawn.

    Args:
        hex_codes: Colour strings accepted by ``matplotlib.colors.hex2color``,
            or ``None``.

    Returns:
        The matplotlib figure.
    """
    hex_codes = [code for code in hex_codes if code is not None]
    rgb_values = [hex2color(hex_code) for hex_code in hex_codes]
    luminance_values = [0.299 * r + 0.587 * g + 0.114 * b for r, g, b in rgb_values]
    sorted_hex_codes = [code for _, code in sorted(zip(luminance_values, hex_codes), reverse=True)]
    fig, axes = plt.subplots(2, 5, figsize=(5, 5))
    # Adjust this figure directly rather than pyplot's global "current figure",
    # which other concurrently running handlers may also be using.
    fig.subplots_adjust(hspace=0.1, wspace=0.1)
    for i, ax in enumerate(axes.flat):
        ax.set_axis_off()
        if i < len(sorted_hex_codes):
            ax.add_patch(plt.Rectangle((0, 0), 1, 1, color=sorted_hex_codes[i]))
    return fig

@spaces.GPU
def generate_images_plots(prompt, model_name):
    """Generate ten images for ``prompt`` and summarise skin tone and gender.

    For each image:
    - a BLIP caption (prefix ``"photo of a "``) is turned into a gender label;
    - the image is saved under ``temp/``;
    - ``stone`` reads the dominant colour of its first detected face.

    Returns:
        ``(images, skin_tone_figure, gender_figure)`` for the gallery and the
        two plot components.
    """
    foldername = "temp"
    Path(foldername).mkdir(parents=True, exist_ok=True)
    prompt_prefix = "photo of a "
    # getimgen obtains the pipeline for model_name itself, so no separate
    # load is needed here. The previous call passed an undefined `token`
    # and raised NameError on every click.
    images = [getimgen(prompt, model_name) for _ in range(10)]
    genders = []
    skintones = []
    for i, image in enumerate(images):
        caption = blip_caption_image(image, prefix=prompt_prefix)
        image_path = f"{foldername}/image_{i}.png"
        image.save(image_path)
        try:
            skintoneres = stone.process(image_path, return_report_image=False)
            tone = skintoneres['faces'][0]['dominant_colors'][0]['color']
        except Exception:
            # Best effort: if stone fails or returns no face (faces[0] raises),
            # record no tone for this image instead of aborting the batch.
            tone = None
        skintones.append(tone)
        genders.append(genderfromcaption(caption))
    return images, skintoneplot(skintones), genderplot(genders)

# Repo ids selectable in the UI, in dropdown order.
_MODEL_CHOICES = [
    "stabilityai/sdxl-turbo",
    "runwayml/stable-diffusion-v1-5",
    "ByteDance/SDXL-Lightning",
    "segmind/SSD-1B",
    "stabilityai/stable-diffusion-3-medium-diffusers",
]
_APP_TITLE = "Skin Tone and Gender bias in Text to Image Models"

# Layout (top to bottom): heading, model picker, prompt, 2x5 gallery,
# generate button, then the skin-tone and gender plots side by side.
with gr.Blocks(title=_APP_TITLE) as demo:
    gr.Markdown(f"# {_APP_TITLE}")
    model_dropdown = gr.Dropdown(
        choices=_MODEL_CHOICES,
        value=default_model,
        label="Choose a model",
    )
    prompt = gr.Textbox(label="Enter the Prompt")
    gallery = gr.Gallery(
        label="Generated images",
        show_label=False,
        elem_id="gallery",
        columns=[5],
        rows=[2],
        object_fit="contain",
        height="auto",
    )
    btn = gr.Button("Generate images", scale=0)
    with gr.Row(equal_height=True):
        skinplot = gr.Plot(label="Skin Tone")
        genplot = gr.Plot(label="Gender")
    btn.click(
        fn=generate_images_plots,
        inputs=[prompt, model_dropdown],
        outputs=[gallery, skinplot, genplot],
    )

demo.launch(debug=True)