Adds Age Classifier and NSFW Classifier

#1
by lucianosb - opened
Files changed (1) hide show
  1. app.py +95 -4
app.py CHANGED
@@ -136,6 +136,87 @@ def skintoneplot(hex_codes):
136
  ax.add_patch(plt.Rectangle((0, 0), 1, 1, color=sorted_hex_codes[i]))
137
  return fig
138
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
  @spaces.GPU(duration=200)
140
  def generate_images_plots(prompt, model_name):
141
  global pipeline_text2image
@@ -145,6 +226,8 @@ def generate_images_plots(prompt, model_name):
145
  images = [getimgen(prompt, model_name) for _ in range(10)]
146
  genders = []
147
  skintones = []
 
 
148
  for image, i in zip(images, range(10)):
149
  prompt_prefix = "photo of a "
150
  caption = blip_caption_image(image, prefix=prompt_prefix)
@@ -156,7 +239,9 @@ def generate_images_plots(prompt, model_name):
156
  except:
157
  skintones.append(None)
158
  genders.append(genderfromcaption(caption))
159
- return images, skintoneplot(skintones), genderplot(genders)
 
 
160
 
161
  with gr.Blocks(title="Skin Tone and Gender bias in Text-to-Image Generation Models") as demo:
162
  gr.Markdown("# Skin Tone and Gender bias in Text to Image Models")
@@ -166,7 +251,8 @@ In this demo, we explore the potential biases in text-to-image models by generat
166
  1. **Image Generation**: For each prompt, 10 images are generated using the selected model.
167
  2. **Gender Detection**: The [BLIP caption generator](https://huggingface.co/Salesforce/blip-image-captioning-large) is used to elicit gender markers by identifying words like "man," "boy," "woman," and "girl" in the captions.
168
  3. **Skin Tone Classification**: The [skin-tone-classifier library](https://github.com/ChenglongMa/SkinToneClassifier) is used to extract the skin tones of the generated subjects.
169
-
 
170
 
171
  #### Visualization
172
 
@@ -174,7 +260,9 @@ We create visual grids to represent the data:
174
 
175
  - **Skin Tone Grids**: Skin tones are plotted as exact hex codes rather than using the Fitzpatrick scale, which can be [problematic and limiting for darker skin tones](https://arxiv.org/pdf/2309.05148).
176
  - **Gender Grids**: Light green denotes men, dark green denotes women, and grey denotes cases where the BLIP caption did not specify a binary gender.
177
-
 
 
178
  This demo provides an insightful look into how current text-to-image models handle sensitive attributes, shedding light on areas for improvement and further study.
179
  [Here is an article](https://medium.com/@evijit/analysis-of-ai-generated-images-of-indian-people-for-colorism-and-sexism-b80ff946759f) showing how this space can be used to perform such analyses, using colorism and sexism in India as an example.
180
  ''')
@@ -204,6 +292,9 @@ This demo provides an insightful look into how current text-to-image models hand
204
  with gr.Row(equal_height=True):
205
  skinplot = gr.Plot(label="Skin Tone")
206
  genplot = gr.Plot(label="Gender")
207
- btn.click(generate_images_plots, inputs=[prompt, model_dropdown], outputs=[gallery, skinplot, genplot])
 
 
 
208
 
209
  demo.launch(debug=True)
 
136
  ax.add_patch(plt.Rectangle((0, 0), 1, 1, color=sorted_hex_codes[i]))
137
  return fig
138
 
139
# Age-classification pipelines keyed by device. Building a pipeline loads the
# model, so it is built once per device and reused across calls instead of
# being rebuilt for every image.
_AGE_PIPELINES = {}


def age_detector(image, device=0):
    """
    Detect the age category of the subject in an image.

    Args:
        image: The input image for age detection, in whatever form the
            image-classification pipeline accepts.
        device: Device handed to the pipeline constructor. Defaults to 0,
            the value that was previously hard-coded.

    Returns:
        str: The label with the highest score among the pipeline results.
    """
    # NOTE(review): `pipeline` is presumably transformers.pipeline, imported
    # elsewhere in this file -- confirm.
    pipe = _AGE_PIPELINES.get(device)
    if pipe is None:
        pipe = pipeline('image-classification', model="dima806/faces_age_detection", device=device)
        _AGE_PIPELINES[device] = pipe
    result = pipe(image)
    max_score_item = max(result, key=lambda item: item['score'])
    return max_score_item['label']
153
+
154
def ageplot(agelist):
    """
    Plot a 2x5 grid of colored cells, one per age category in the given list.

    Labels are sorted YOUNG, MIDDLE, OLD before plotting. Labels outside that
    set (including None) are placed last and drawn in light grey instead of
    raising. If fewer than 10 labels are given the remaining cells stay blank;
    labels beyond the tenth are ignored.

    Args:
        agelist (list): A list of age categories ("YOUNG", "MIDDLE", "OLD").

    Returns:
        fig: A matplotlib figure object representing the age plot.
    """
    order = ["YOUNG", "MIDDLE", "OLD"]
    colors = {"YOUNG": "skyblue", "MIDDLE": "royalblue", "OLD": "darkblue"}
    fallback_color = "lightgrey"

    # Rank lookup with a default so an unexpected label sorts to the end
    # rather than raising ValueError from list.index().
    rank = {label: position for position, label in enumerate(order)}
    words = sorted(agelist, key=lambda x: rank.get(x, len(order)))
    word_colors = [colors.get(word, fallback_color) for word in words]

    fig, axes = plt.subplots(2, 5, figsize=(5, 5))
    fig.subplots_adjust(hspace=0.1, wspace=0.1)
    for i, ax in enumerate(axes.flat):
        ax.set_axis_off()
        # Guard against short input: only paint cells that have a label.
        if i < len(word_colors):
            ax.add_patch(plt.Rectangle((0, 0), 1, 1, color=word_colors[i]))
    return fig
174
+
175
# NSFW-classification pipelines keyed by device. Building a pipeline loads the
# model, so it is built once per device and reused across calls instead of
# being rebuilt for every image.
_NSFW_PIPELINES = {}


def is_nsfw(image, device=None):
    """
    Classify an image with the NSFW image-classification pipeline.

    Despite the name, this returns the winning label string, not a boolean.

    Args:
        image: The input image to be classified.
        device: Device handed to the pipeline constructor. Defaults to None,
            which leaves device selection to the pipeline as before.

    Returns:
        str: The label with the highest score among the pipeline results.
    """
    # NOTE(review): `pipeline` is presumably transformers.pipeline, imported
    # elsewhere in this file -- confirm.
    classifier = _NSFW_PIPELINES.get(device)
    if classifier is None:
        classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)
        _NSFW_PIPELINES[device] = classifier
    result = classifier(image)
    max_score_item = max(result, key=lambda item: item['score'])
    return max_score_item['label']
190
+
191
def nsfwplot(nsfwlist):
    """
    Plot a 2x5 grid of colored cells, one per NSFW label in the given list.

    Labels are sorted "normal" first, then "nsfw", and each is drawn as a
    filled cell in its assigned color. Labels outside that set (including
    None) are placed last and drawn in light grey instead of raising. If
    fewer than 10 labels are given the remaining cells stay blank; labels
    beyond the tenth are ignored.

    Args:
        nsfwlist (list): A list of NSFW labels ("normal" or "nsfw").

    Returns:
        fig: A matplotlib figure object representing the NSFW plot.
    """
    order = ["normal", "nsfw"]
    colors = {"normal": "mistyrose", "nsfw": "red"}
    fallback_color = "lightgrey"

    # Rank lookup with a default so an unexpected label sorts to the end
    # rather than raising ValueError from list.index().
    rank = {label: position for position, label in enumerate(order)}
    words = sorted(nsfwlist, key=lambda x: rank.get(x, len(order)))
    word_colors = [colors.get(word, fallback_color) for word in words]

    fig, axes = plt.subplots(2, 5, figsize=(5, 5))
    fig.subplots_adjust(hspace=0.1, wspace=0.1)
    for i, ax in enumerate(axes.flat):
        ax.set_axis_off()
        # Guard against short input: only paint cells that have a label.
        if i < len(word_colors):
            ax.add_patch(plt.Rectangle((0, 0), 1, 1, color=word_colors[i]))
    return fig
219
+
220
  @spaces.GPU(duration=200)
221
  def generate_images_plots(prompt, model_name):
222
  global pipeline_text2image
 
226
  images = [getimgen(prompt, model_name) for _ in range(10)]
227
  genders = []
228
  skintones = []
229
+ ages = []
230
+ nsfws = []
231
  for image, i in zip(images, range(10)):
232
  prompt_prefix = "photo of a "
233
  caption = blip_caption_image(image, prefix=prompt_prefix)
 
239
  except:
240
  skintones.append(None)
241
  genders.append(genderfromcaption(caption))
242
+ ages.append(age)
243
+ nsfws.append(nsfw)
244
+ return images, skintoneplot(skintones), genderplot(genders), ageplot(ages), nsfwplot(nsfws)
245
 
246
  with gr.Blocks(title="Skin Tone and Gender bias in Text-to-Image Generation Models") as demo:
247
  gr.Markdown("# Skin Tone and Gender bias in Text to Image Models")
 
251
  1. **Image Generation**: For each prompt, 10 images are generated using the selected model.
252
  2. **Gender Detection**: The [BLIP caption generator](https://huggingface.co/Salesforce/blip-image-captioning-large) is used to elicit gender markers by identifying words like "man," "boy," "woman," and "girl" in the captions.
253
  3. **Skin Tone Classification**: The [skin-tone-classifier library](https://github.com/ChenglongMa/SkinToneClassifier) is used to extract the skin tones of the generated subjects.
254
+ 4. **Age Detection**: The [Faces Age Detection model](https://huggingface.co/dima806/faces_age_detection) is used to identify the age of the generated subjects.
255
+ 5. **NSFW Detection**: The [Falconsai/nsfw_image_detection](https://huggingface.co/Falconsai/nsfw_image_detection) model is used to identify whether the generated images are NSFW (not safe for work).
256
 
257
  #### Visualization
258
 
 
260
 
261
  - **Skin Tone Grids**: Skin tones are plotted as exact hex codes rather than using the Fitzpatrick scale, which can be [problematic and limiting for darker skin tones](https://arxiv.org/pdf/2309.05148).
262
  - **Gender Grids**: Light green denotes men, dark green denotes women, and grey denotes cases where the BLIP caption did not specify a binary gender.
263
+ - **Age Grids**: Light blue denotes people between 18 and 30, blue denotes people between 30 and 50, and dark blue denotes people older than 50.
264
+ - **NSFW Grids**: Light red denotes SFW images, and dark red denotes NSFW images.
265
+
266
  This demo provides an insightful look into how current text-to-image models handle sensitive attributes, shedding light on areas for improvement and further study.
267
  [Here is an article](https://medium.com/@evijit/analysis-of-ai-generated-images-of-indian-people-for-colorism-and-sexism-b80ff946759f) showing how this space can be used to perform such analyses, using colorism and sexism in India as an example.
268
  ''')
 
292
  with gr.Row(equal_height=True):
293
  skinplot = gr.Plot(label="Skin Tone")
294
  genplot = gr.Plot(label="Gender")
295
+ with gr.Row(equal_height=True):
296
+ agesplot = gr.Plot(label="Age")
297
+ nsfwsplot = gr.Plot(label="NSFW")
298
+ btn.click(generate_images_plots, inputs=[prompt, model_dropdown], outputs=[gallery, skinplot, genplot, agesplot, nsfwsplot])
299
 
300
  demo.launch(debug=True)