Spaces:
Running
on
Zero
Running
on
Zero
Adds Age Classifier and NSFW Classifier
Browse filesThe new functions use https://huggingface.co/dima806/faces_age_detection and https://huggingface.co/Falconsai/nsfw_image_detection models.
app.py
CHANGED
@@ -136,6 +136,87 @@ def skintoneplot(hex_codes):
|
|
136 |
ax.add_patch(plt.Rectangle((0, 0), 1, 1, color=sorted_hex_codes[i]))
|
137 |
return fig
|
138 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
139 |
@spaces.GPU(duration=200)
|
140 |
def generate_images_plots(prompt, model_name):
|
141 |
global pipeline_text2image
|
@@ -145,6 +226,8 @@ def generate_images_plots(prompt, model_name):
|
|
145 |
images = [getimgen(prompt, model_name) for _ in range(10)]
|
146 |
genders = []
|
147 |
skintones = []
|
|
|
|
|
148 |
for image, i in zip(images, range(10)):
|
149 |
prompt_prefix = "photo of a "
|
150 |
caption = blip_caption_image(image, prefix=prompt_prefix)
|
@@ -156,7 +239,9 @@ def generate_images_plots(prompt, model_name):
|
|
156 |
except:
|
157 |
skintones.append(None)
|
158 |
genders.append(genderfromcaption(caption))
|
159 |
-
|
|
|
|
|
160 |
|
161 |
with gr.Blocks(title="Skin Tone and Gender bias in Text-to-Image Generation Models") as demo:
|
162 |
gr.Markdown("# Skin Tone and Gender bias in Text to Image Models")
|
@@ -166,7 +251,8 @@ In this demo, we explore the potential biases in text-to-image models by generat
|
|
166 |
1. **Image Generation**: For each prompt, 10 images are generated using the selected model.
|
167 |
2. **Gender Detection**: The [BLIP caption generator](https://huggingface.co/Salesforce/blip-image-captioning-large) is used to elicit gender markers by identifying words like "man," "boy," "woman," and "girl" in the captions.
|
168 |
3. **Skin Tone Classification**: The [skin-tone-classifier library](https://github.com/ChenglongMa/SkinToneClassifier) is used to extract the skin tones of the generated subjects.
|
169 |
-
|
|
|
170 |
|
171 |
#### Visualization
|
172 |
|
@@ -174,7 +260,9 @@ We create visual grids to represent the data:
|
|
174 |
|
175 |
- **Skin Tone Grids**: Skin tones are plotted as exact hex codes rather than using the Fitzpatrick scale, which can be [problematic and limiting for darker skin tones](https://arxiv.org/pdf/2309.05148).
|
176 |
- **Gender Grids**: Light green denotes men, dark green denotes women, and grey denotes cases where the BLIP caption did not specify a binary gender.
|
177 |
-
|
|
|
|
|
178 |
This demo provides an insightful look into how current text-to-image models handle sensitive attributes, shedding light on areas for improvement and further study.
|
179 |
[Here is an article](https://medium.com/@evijit/analysis-of-ai-generated-images-of-indian-people-for-colorism-and-sexism-b80ff946759f) showing how this space can be used to perform such analyses, using colorism and sexism in India as an example.
|
180 |
''')
|
@@ -204,6 +292,9 @@ This demo provides an insightful look into how current text-to-image models hand
|
|
204 |
with gr.Row(equal_height=True):
|
205 |
skinplot = gr.Plot(label="Skin Tone")
|
206 |
genplot = gr.Plot(label="Gender")
|
207 |
-
|
|
|
|
|
|
|
208 |
|
209 |
demo.launch(debug=True)
|
|
|
136 |
ax.add_patch(plt.Rectangle((0, 0), 1, 1, color=sorted_hex_codes[i]))
|
137 |
return fig
|
138 |
|
139 |
+
def age_detector(image):
|
140 |
+
"""
|
141 |
+
A function that detects the age from an image.
|
142 |
+
|
143 |
+
Args:
|
144 |
+
image: The input image for age detection.
|
145 |
+
|
146 |
+
Returns:
|
147 |
+
str: The detected age label from the image.
|
148 |
+
"""
|
149 |
+
pipe = pipeline('image-classification', model="dima806/faces_age_detection", device=0)
|
150 |
+
result = pipe(image)
|
151 |
+
max_score_item = max(result, key=lambda item: item['score'])
|
152 |
+
return max_score_item['label']
|
153 |
+
|
154 |
+
def ageplot(agelist):
|
155 |
+
"""
|
156 |
+
A function that plots age-related data based on the given list of age categories.
|
157 |
+
|
158 |
+
Args:
|
159 |
+
agelist (list): A list of age categories ("YOUNG", "MIDDLE", "OLD").
|
160 |
+
|
161 |
+
Returns:
|
162 |
+
fig: A matplotlib figure object representing the age plot.
|
163 |
+
"""
|
164 |
+
order = ["YOUNG", "MIDDLE", "OLD"]
|
165 |
+
words = sorted(agelist, key=lambda x: order.index(x))
|
166 |
+
colors = {"YOUNG": "skyblue", "MIDDLE": "royalblue", "OLD": "darkblue"}
|
167 |
+
word_colors = [colors[word] for word in words]
|
168 |
+
fig, axes = plt.subplots(2, 5, figsize=(5,5))
|
169 |
+
plt.subplots_adjust(hspace=0.1, wspace=0.1)
|
170 |
+
for i, ax in enumerate(axes.flat):
|
171 |
+
ax.set_axis_off()
|
172 |
+
ax.add_patch(plt.Rectangle((0, 0), 1, 1, color=word_colors[i]))
|
173 |
+
return fig
|
174 |
+
|
175 |
+
def is_nsfw(image):
|
176 |
+
"""
|
177 |
+
A function that checks if the input image is not safe for work (NSFW) by classifying it using
|
178 |
+
an image classification pipeline and returning the label with the highest score.
|
179 |
+
|
180 |
+
Args:
|
181 |
+
image: The input image to be classified.
|
182 |
+
|
183 |
+
Returns:
|
184 |
+
str: The label of the NSFW category with the highest score.
|
185 |
+
"""
|
186 |
+
classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection")
|
187 |
+
result = classifier(image)
|
188 |
+
max_score_item = max(result, key=lambda item: item['score'])
|
189 |
+
return max_score_item['label']
|
190 |
+
|
191 |
+
def nsfwplot(nsfwlist):
|
192 |
+
"""
|
193 |
+
Generates a plot of NSFW categories based on a list of NSFW labels.
|
194 |
+
|
195 |
+
Args:
|
196 |
+
nsfwlist (list): A list of NSFW labels ("normal" or "nsfw").
|
197 |
+
|
198 |
+
Returns:
|
199 |
+
fig: A matplotlib figure object representing the NSFW plot.
|
200 |
+
|
201 |
+
Raises:
|
202 |
+
None
|
203 |
+
|
204 |
+
This function takes a list of NSFW labels and generates a plot with a grid of 2 rows and 5 columns.
|
205 |
+
Each label is sorted based on a predefined order and assigned a color. The plot is then created using matplotlib,
|
206 |
+
with each cell representing an NSFW label. The color of each cell is determined by the corresponding label's color.
|
207 |
+
The function returns the generated figure object.
|
208 |
+
"""
|
209 |
+
order = ["normal", "nsfw"]
|
210 |
+
words = sorted(nsfwlist, key=lambda x: order.index(x))
|
211 |
+
colors = {"normal": "mistyrose", "nsfw": "red"}
|
212 |
+
word_colors = [colors[word] for word in words]
|
213 |
+
fig, axes = plt.subplots(2, 5, figsize=(5,5))
|
214 |
+
plt.subplots_adjust(hspace=0.1, wspace=0.1)
|
215 |
+
for i, ax in enumerate(axes.flat):
|
216 |
+
ax.set_axis_off()
|
217 |
+
ax.add_patch(plt.Rectangle((0, 0), 1, 1, color=word_colors[i]))
|
218 |
+
return fig
|
219 |
+
|
220 |
@spaces.GPU(duration=200)
|
221 |
def generate_images_plots(prompt, model_name):
|
222 |
global pipeline_text2image
|
|
|
226 |
images = [getimgen(prompt, model_name) for _ in range(10)]
|
227 |
genders = []
|
228 |
skintones = []
|
229 |
+
ages = []
|
230 |
+
nsfws = []
|
231 |
for image, i in zip(images, range(10)):
|
232 |
prompt_prefix = "photo of a "
|
233 |
caption = blip_caption_image(image, prefix=prompt_prefix)
|
|
|
239 |
except:
|
240 |
skintones.append(None)
|
241 |
genders.append(genderfromcaption(caption))
|
242 |
+
ages.append(age)
|
243 |
+
nsfws.append(nsfw)
|
244 |
+
return images, skintoneplot(skintones), genderplot(genders), ageplot(ages), nsfwplot(nsfws)
|
245 |
|
246 |
with gr.Blocks(title="Skin Tone and Gender bias in Text-to-Image Generation Models") as demo:
|
247 |
gr.Markdown("# Skin Tone and Gender bias in Text to Image Models")
|
|
|
251 |
1. **Image Generation**: For each prompt, 10 images are generated using the selected model.
|
252 |
2. **Gender Detection**: The [BLIP caption generator](https://huggingface.co/Salesforce/blip-image-captioning-large) is used to elicit gender markers by identifying words like "man," "boy," "woman," and "girl" in the captions.
|
253 |
3. **Skin Tone Classification**: The [skin-tone-classifier library](https://github.com/ChenglongMa/SkinToneClassifier) is used to extract the skin tones of the generated subjects.
|
254 |
+
4. **Age Detection**: The [Faces Age Detection model](https://huggingface.co/dima806/faces_age_detection) is used to identify the age of the generated subjects.
|
255 |
+
5. **NSFW Detection**: The [Falconsai/nsfw_image_detection](https://huggingface.co/Falconsai/nsfw_image_detection) model is used to identify whether the generated images are NSFW (not safe for work).
|
256 |
|
257 |
#### Visualization
|
258 |
|
|
|
260 |
|
261 |
- **Skin Tone Grids**: Skin tones are plotted as exact hex codes rather than using the Fitzpatrick scale, which can be [problematic and limiting for darker skin tones](https://arxiv.org/pdf/2309.05148).
|
262 |
- **Gender Grids**: Light green denotes men, dark green denotes women, and grey denotes cases where the BLIP caption did not specify a binary gender.
|
263 |
+
- **Age Grids**: Light blue denotes people between 18 and 30, blue denotes people between 30 and 50, and dark blue denotes people older than 50.
|
264 |
+
- **NSFW Grids**: Light red denotes SFW images, and dark red denotes NSFW images.
|
265 |
+
|
266 |
This demo provides an insightful look into how current text-to-image models handle sensitive attributes, shedding light on areas for improvement and further study.
|
267 |
[Here is an article](https://medium.com/@evijit/analysis-of-ai-generated-images-of-indian-people-for-colorism-and-sexism-b80ff946759f) showing how this space can be used to perform such analyses, using colorism and sexism in India as an example.
|
268 |
''')
|
|
|
292 |
with gr.Row(equal_height=True):
|
293 |
skinplot = gr.Plot(label="Skin Tone")
|
294 |
genplot = gr.Plot(label="Gender")
|
295 |
+
with gr.Row(equal_height=True):
|
296 |
+
agesplot = gr.Plot(label="Age")
|
297 |
+
nsfwsplot = gr.Plot(label="NSFW")
|
298 |
+
btn.click(generate_images_plots, inputs=[prompt, model_dropdown], outputs=[gallery, skinplot, genplot, agesplot, nsfwsplot])
|
299 |
|
300 |
demo.launch(debug=True)
|