lucianosb committed on
Commit 061d3a6 · verified · 1 Parent(s): db9e2d6

Adds two more bias detectors

Files changed (1): app.py (+195 -8)
app.py CHANGED
@@ -8,7 +8,7 @@ from diffusers import (
     UNet2DConditionModel,
     StableDiffusion3Pipeline
 )
-from transformers import BlipProcessor, BlipForConditionalGeneration
+from transformers import BlipProcessor, BlipForConditionalGeneration, pipeline
 from pathlib import Path
 from safetensors.torch import load_file
 from huggingface_hub import hf_hub_download
@@ -28,6 +28,19 @@ login(token = access_token)
 
 # Define model initialization functions
 def load_model(model_name):
+    """
+    Load a StableDiffusionXLPipeline from a single file.
+
+    Args:
+        model_name (str): The name of the model.
+
+    Returns:
+        StableDiffusionXLPipeline: The loaded pipeline.
+
+    Raises:
+        ValueError: If the model name is unknown.
+
+    """
     if model_name == "sinteticoXL":
         pipeline = StableDiffusionXLPipeline.from_single_file(
             "https://huggingface.co/lucianosb/sinteticoXL-models/blob/main/sinteticoXL_v1dot2.safetensors",
@@ -52,6 +65,16 @@ pipeline_text2image = load_model(default_model)
 
 @spaces.GPU
 def getimgen(prompt, model_name):
+    """
+    This function generates an image based on the prompt and the specified model name.
+
+    Args:
+        prompt (str): The input prompt for generating the image.
+        model_name (str): The name of the model to use for image generation.
+
+    Returns:
+        Image: The generated image based on the prompt and model.
+    """
     if model_name == "sinteticoXL":
         return pipeline_text2image(prompt=prompt, guidance_scale=6.0, num_inference_steps=20).images[0]
     elif model_name == "sinteticoXL_Prude":
@@ -62,11 +85,30 @@ blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image
 
 @spaces.GPU
 def blip_caption_image(image, prefix):
+    """
+    This function generates a caption for the input image based on the provided prefix.
+
+    Args:
+        image: The input image for which the caption is generated.
+        prefix: The prefix used for caption generation.
+
+    Returns:
+        str: The generated caption for the input image.
+    """
     inputs = blip_processor(image, prefix, return_tensors="pt").to("cuda", torch.float16)
     out = blip_model.generate(**inputs)
     return blip_processor.decode(out[0], skip_special_tokens=True)
 
 def genderfromcaption(caption):
+    """
+    A function that determines the gender based on the input caption.
+
+    Args:
+        caption (str): The caption for which the gender needs to be determined.
+
+    Returns:
+        str: The gender identified from the caption (either "Man", "Woman", or "Unsure").
+    """
     cc = caption.split()
     if "man" in cc or "boy" in cc:
         return "Man"
@@ -75,6 +117,15 @@ def genderfromcaption(caption):
         return "Unsure"
 
 def genderplot(genlist):
+    """
+    A function that plots gender-related data based on the given list of genders.
+
+    Args:
+        genlist (list): A list of gender labels ("Man", "Woman", or "Unsure").
+
+    Returns:
+        fig: A matplotlib figure object representing the gender plot.
+    """
     order = ["Man", "Woman", "Unsure"]
     words = sorted(genlist, key=lambda x: order.index(x))
     colors = {"Man": "lightgreen", "Woman": "darkgreen", "Unsure": "lightgrey"}
@@ -86,7 +137,107 @@ def genderplot(genlist):
         ax.add_patch(plt.Rectangle((0, 0), 1, 1, color=word_colors[i]))
     return fig
 
+def age_detector(image):
+    """
+    A function that detects the age from an image.
+
+    Args:
+        image: The input image for age detection.
+
+    Returns:
+        str: The detected age label from the image.
+    """
+    pipe = pipeline('image-classification', model="dima806/faces_age_detection", device=0)
+    result = pipe(image)
+    max_score_item = max(result, key=lambda item: item['score'])
+    return max_score_item['label']
+
+def ageplot(agelist):
+    """
+    A function that plots age-related data based on the given list of age categories.
+
+    Args:
+        agelist (list): A list of age categories ("YOUNG", "MIDDLE", "OLD").
+
+    Returns:
+        fig: A matplotlib figure object representing the age plot.
+    """
+    order = ["YOUNG", "MIDDLE", "OLD"]
+    words = sorted(agelist, key=lambda x: order.index(x))
+    colors = {"YOUNG": "skyblue", "MIDDLE": "royalblue", "OLD": "darkblue"}
+    word_colors = [colors[word] for word in words]
+    fig, axes = plt.subplots(2, 5, figsize=(5,5))
+    plt.subplots_adjust(hspace=0.1, wspace=0.1)
+    for i, ax in enumerate(axes.flat):
+        ax.set_axis_off()
+        ax.add_patch(plt.Rectangle((0, 0), 1, 1, color=word_colors[i]))
+    return fig
+
+def is_nsfw(image):
+    """
+    A function that checks if the input image is not safe for work (NSFW) by classifying it using
+    an image classification pipeline and returning the label with the highest score.
+
+    Args:
+        image: The input image to be classified.
+
+    Returns:
+        str: The label of the NSFW category with the highest score.
+    """
+    classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection")
+    result = classifier(image)
+    max_score_item = max(result, key=lambda item: item['score'])
+    return max_score_item['label']
+
+def nsfwplot(nsfwlist):
+    """
+    Generates a plot of NSFW categories based on a list of NSFW labels.
+
+    Args:
+        nsfwlist (list): A list of NSFW labels ("normal" or "nsfw").
+
+    Returns:
+        fig: A matplotlib figure object representing the NSFW plot.
+
+    Raises:
+        None
+
+    This function takes a list of NSFW labels and generates a plot with a grid of 2 rows and 5 columns.
+    Each label is sorted based on a predefined order and assigned a color. The plot is then created using matplotlib,
+    with each cell representing an NSFW label. The color of each cell is determined by the corresponding label's color.
+    The function returns the generated figure object.
+    """
+    order = ["normal", "nsfw"]
+    words = sorted(nsfwlist, key=lambda x: order.index(x))
+    colors = {"normal": "mistyrose", "nsfw": "red"}
+    word_colors = [colors[word] for word in words]
+    fig, axes = plt.subplots(2, 5, figsize=(5,5))
+    plt.subplots_adjust(hspace=0.1, wspace=0.1)
+    for i, ax in enumerate(axes.flat):
+        ax.set_axis_off()
+        ax.add_patch(plt.Rectangle((0, 0), 1, 1, color=word_colors[i]))
+    return fig
+
 def skintoneplot(hex_codes):
+    """
+    Generates a plot of skin tones based on a list of hexadecimal color codes.
+
+    Args:
+        hex_codes (list): A list of hexadecimal color codes.
+
+    Returns:
+        fig: A matplotlib figure object representing the skin tone plot.
+
+    Raises:
+        None
+
+    This function takes a list of hexadecimal color codes and generates a plot with a grid of 2 rows and 5 columns.
+    Each color code is converted to its corresponding RGB value and then the luminance value is calculated using the
+    formula: luminance = 0.299 * R + 0.587 * G + 0.114 * B. The colors are then sorted based on their luminance values
+    in descending order and assigned to the corresponding cells in the plot. The plot is created using matplotlib,
+    with each cell representing a skin tone. The color of each cell is determined by the corresponding skin tone color.
+    The function returns the generated figure object.
+    """
     hex_codes = [code for code in hex_codes if code is not None]
     rgb_values = [hex2color(hex_code) for hex_code in hex_codes]
     luminance_values = [0.299 * r + 0.587 * g + 0.114 * b for r, g, b in rgb_values]
@@ -101,6 +252,20 @@ def skintoneplot(hex_codes):
 
 @spaces.GPU(duration=200)
 def generate_images_plots(prompt, model_name):
+    """
+    This function generates images, extracts information like genders, skintones, ages, and nsfw labels from the images,
+    and returns a tuple containing the images and plots for skintones, genders, ages, and nsfw labels.
+
+    Args:
+        prompt (str): The prompt for generating images.
+        model_name (str): The name of the model.
+
+    Returns:
+        tuple: A tuple containing the images generated, skintone plot, gender plot, age plot, and nsfw plot.
+
+    Raises:
+        None
+    """
     global pipeline_text2image
     pipeline_text2image = load_model(model_name)
     foldername = "temp"
@@ -108,9 +273,13 @@ def generate_images_plots(prompt, model_name):
     images = [getimgen(prompt, model_name) for _ in range(10)]
     genders = []
    skintones = []
+    ages = []
+    nsfws = []
     for image, i in zip(images, range(10)):
         prompt_prefix = "photo of a "
         caption = blip_caption_image(image, prefix=prompt_prefix)
+        age = age_detector(image)
+        nsfw = is_nsfw(image)
         image.save(f"{foldername}/image_{i}.png")
         try:
             skintoneres = stone.process(f"{foldername}/image_{i}.png", return_report_image=False)
@@ -119,31 +288,46 @@ def generate_images_plots(prompt, model_name):
         except:
             skintones.append(None)
         genders.append(genderfromcaption(caption))
-    return images, skintoneplot(skintones), genderplot(genders)
+        ages.append(age)
+        nsfws.append(nsfw)
+    return images, skintoneplot(skintones), genderplot(genders), ageplot(ages), nsfwplot(nsfws)
 
-with gr.Blocks(title="Skin Tone and Gender bias in Text-to-Image Generation Models") as demo:
-    gr.Markdown("# Skin Tone and Gender bias in Text to Image Models")
+with gr.Blocks(title="Bias detection in SinteticoXL Models") as demo:
+    gr.Markdown("# Bias detection in SinteticoXL Models")
     gr.Markdown('''
-In this demo, we explore the potential biases in text-to-image models by generating multiple images based on user prompts and analyzing the gender and skin tone of the generated subjects. Here's how the analysis works:
+In this demo, we explore the potential biases in text-to-image models by generating multiple images based on user prompts and analyzing the gender, skin tone, and age of the generated subjects as well as the potential for NSFW content. Here's how the analysis works:
 
 1. **Image Generation**: For each prompt, 10 images are generated using the selected model.
 2. **Gender Detection**: The [BLIP caption generator](https://huggingface.co/Salesforce/blip-image-captioning-large) is used to elicit gender markers by identifying words like "man," "boy," "woman," and "girl" in the captions.
 3. **Skin Tone Classification**: The [skin-tone-classifier library](https://github.com/ChenglongMa/SkinToneClassifier) is used to extract the skin tones of the generated subjects.
+4. **Age Detection**: The [Faces Age Detection model](https://huggingface.co/dima806/faces_age_detection) is used to identify the age of the generated subjects.
+5. **NSFW Detection**: The [Falconsai/nsfw_image_detection](https://huggingface.co/Falconsai/nsfw_image_detection) model is used to identify whether the generated images are NSFW (not safe for work).
 
 ## Models
 
 - Sintetico XL: a merged model with my favorite aesthetics
 - Sintetico XL Prude: a SFW version that aims to remove unwanted nudity and sexual content.
 
+
+    ''')
+    with gr.Accordion("Open for More Information!", open=False):
+        gr.Markdown('''
+    This space was clone from [JournalistsonHF/text-to-image-bias](https://huggingface.co/spaces/JournalistsonHF/text-to-image-bias).
+
+    👉 It's also in line with "Stable Bias" work by Hugging Face's ML & Society team: https://huggingface.co/spaces/society-ethics/StableBias
+
+    This demo provides an insightful look into how current text-to-image models handle sensitive attributes, shedding light on areas for improvement and further study.
+    [Here is an article](https://medium.com/@evijit/analysis-of-ai-generated-images-of-indian-people-for-colorism-and-sexism-b80ff946759f) showing how this space can be used to perform such analyses, using colorism and sexism in India as an example.
+
 #### Visualization
 
 We create visual grids to represent the data:
 
 - **Skin Tone Grids**: Skin tones are plotted as exact hex codes rather than using the Fitzpatrick scale, which can be [problematic and limiting for darker skin tones](https://arxiv.org/pdf/2309.05148).
 - **Gender Grids**: Light green denotes men, dark green denotes women, and grey denotes cases where the BLIP caption did not specify a binary gender.
+- **Age Grids**: Light blue denotes people between 18 and 30, blue denotes people between 30 and 50, and dark blue denotes people older than 50.
+- **NSFW Grids**: Light red denotes SFW images, and dark red denotes NSFW images.
 
-This demo provides an insightful look into how current text-to-image models handle sensitive attributes, shedding light on areas for improvement and further study.
-[Here is an article](https://medium.com/@evijit/analysis-of-ai-generated-images-of-indian-people-for-colorism-and-sexism-b80ff946759f) showing how this space can be used to perform such analyses, using colorism and sexism in India as an example.
     ''')
     model_dropdown = gr.Dropdown(
         label="Choose a model",
@@ -167,6 +351,9 @@ This demo provides an insightful look into how current text-to-image models hand
     with gr.Row(equal_height=True):
         skinplot = gr.Plot(label="Skin Tone")
         genplot = gr.Plot(label="Gender")
-    btn.click(generate_images_plots, inputs=[prompt, model_dropdown], outputs=[gallery, skinplot, genplot])
+    with gr.Row(equal_height=True):
+        agesplot = gr.Plot(label="Age")
+        nsfwsplot = gr.Plot(label="NSFW")
+    btn.click(generate_images_plots, inputs=[prompt, model_dropdown], outputs=[gallery, skinplot, genplot, agesplot, nsfwsplot])
 
 demo.launch(debug=True)
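
To try the two new detectors outside the Space, here is a minimal sketch of the same calls the updated app.py makes, run on a single image. It assumes `transformers`, `torch`, and `Pillow` are installed; the image path is a placeholder for any generated sample, and loading each classifier once (rather than per image, as the helper functions in the commit do) is an optional simplification.

```python
# Minimal sketch (not part of the commit): run the two new classifiers on one image.
from PIL import Image
from transformers import pipeline

image = Image.open("temp/image_0.png")  # placeholder path; use any generated sample

# Age detection with dima806/faces_age_detection; ageplot() expects "YOUNG", "MIDDLE", or "OLD"
age_pipe = pipeline("image-classification", model="dima806/faces_age_detection")
age_label = max(age_pipe(image), key=lambda item: item["score"])["label"]

# NSFW detection with Falconsai/nsfw_image_detection; nsfwplot() expects "normal" or "nsfw"
nsfw_pipe = pipeline("image-classification", model="Falconsai/nsfw_image_detection")
nsfw_label = max(nsfw_pipe(image), key=lambda item: item["score"])["label"]

print(f"age: {age_label}, nsfw: {nsfw_label}")
```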