evijit (HF staff) committed · verified
Commit a1124c1 · Parent(s): 1b7f205

Update app.py

Files changed (1):
  1. app.py +24 -60
app.py CHANGED
@@ -6,7 +6,8 @@ from diffusers import (
     StableDiffusionXLPipeline,
     EulerDiscreteScheduler,
     UNet2DConditionModel,
-    StableDiffusion3Pipeline
+    StableDiffusion3Pipeline,
+    FluxPipeline
 )
 from transformers import BlipProcessor, BlipForConditionalGeneration
 from pathlib import Path
@@ -21,11 +22,9 @@ import spaces
 
 access_token = os.getenv("AccessTokenSD3")
 
-
 from huggingface_hub import login
 login(token = access_token)
 
-
 # Define model initialization functions
 def load_model(model_name):
     if model_name == "stabilityai/sdxl-turbo":
@@ -65,6 +64,9 @@ def load_model(model_name):
             scheduler=scheduler,
             torch_dtype=torch.float16
         ).to("cuda")
+    elif model_name == "black-forest-labs/FLUX.1-dev":
+        pipeline = FluxPipeline.from_pretrained(model_name, torch_dtype=torch.bfloat16)
+        pipeline.enable_model_cpu_offload()
     else:
         raise ValueError("Unknown model name")
     return pipeline
@@ -76,16 +78,26 @@ pipeline_text2image = load_model(default_model)
 @spaces.GPU
 def getimgen(prompt, model_name):
     if model_name == "stabilityai/sdxl-turbo":
-        return pipeline_text2image(prompt=prompt, guidance_scale=0.0, num_inference_steps=2).images[0]
+        return pipeline_text2image(prompt=prompt, guidance_scale=0.0, num_inference_steps=2, height=512, width=512).images[0]
     elif model_name == "ByteDance/SDXL-Lightning":
-        return pipeline_text2image(prompt, num_inference_steps=4, guidance_scale=0).images[0]
+        return pipeline_text2image(prompt, num_inference_steps=4, guidance_scale=0, height=512, width=512).images[0]
     elif model_name == "segmind/SSD-1B":
         neg_prompt = "ugly, blurry, poor quality"
-        return pipeline_text2image(prompt=prompt, negative_prompt=neg_prompt).images[0]
+        return pipeline_text2image(prompt=prompt, negative_prompt=neg_prompt, height=512, width=512).images[0]
     elif model_name == "stabilityai/stable-diffusion-3-medium-diffusers":
-        return pipeline_text2image(prompt=prompt, negative_prompt="", num_inference_steps=28, guidance_scale=7.0).images[0]
+        return pipeline_text2image(prompt=prompt, negative_prompt="", num_inference_steps=28, guidance_scale=7.0, height=512, width=512).images[0]
     elif model_name == "stabilityai/stable-diffusion-2":
-        return pipeline_text2image(prompt=prompt).images[0]
+        return pipeline_text2image(prompt=prompt, height=512, width=512).images[0]
+    elif model_name == "black-forest-labs/FLUX.1-dev":
+        return pipeline_text2image(
+            prompt,
+            height=512,
+            width=512,
+            guidance_scale=3.5,
+            num_inference_steps=50,
+            max_sequence_length=512,
+            generator=torch.Generator("cpu").manual_seed(0)
+        ).images[0]
 
 blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
 blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large", torch_dtype=torch.float16).to("cuda")
@@ -130,30 +142,12 @@ def skintoneplot(hex_codes):
     return fig
 
 def age_detector(image):
-    """
-    A function that detects the age from an image.
-
-    Args:
-        image: The input image for age detection.
-
-    Returns:
-        str: The detected age label from the image.
-    """
     pipe = pipeline('image-classification', model="dima806/faces_age_detection", device=0)
     result = pipe(image)
     max_score_item = max(result, key=lambda item: item['score'])
     return max_score_item['label']
 
 def ageplot(agelist):
-    """
-    A function that plots age-related data based on the given list of age categories.
-
-    Args:
-        agelist (list): A list of age categories ("YOUNG", "MIDDLE", "OLD").
-
-    Returns:
-        fig: A matplotlib figure object representing the age plot.
-    """
     order = ["YOUNG", "MIDDLE", "OLD"]
     words = sorted(agelist, key=lambda x: order.index(x))
     colors = {"YOUNG": "skyblue", "MIDDLE": "royalblue", "OLD": "darkblue"}
@@ -166,39 +160,12 @@ def ageplot(agelist):
     return fig
 
 def is_nsfw(image):
-    """
-    A function that checks if the input image is not for all audiences (NFAA) by classifying it using
-    an image classification pipeline and returning the label with the highest score.
-
-    Args:
-        image: The input image to be classified.
-
-    Returns:
-        str: The label of the NFAA category with the highest score.
-    """
     classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection")
     result = classifier(image)
     max_score_item = max(result, key=lambda item: item['score'])
     return max_score_item['label']
 
 def nsfwplot(nsfwlist):
-    """
-    Generates a plot of NFAA categories based on a list of NFAA labels.
-
-    Args:
-        nsfwlist (list): A list of NSFW labels ("normal" or "nsfw").
-
-    Returns:
-        fig: A matplotlib figure object representing the NSFW plot.
-
-    Raises:
-        None
-
-    This function takes a list of NFAA labels and generates a plot with a grid of 2 rows and 5 columns.
-    Each label is sorted based on a predefined order and assigned a color. The plot is then created using matplotlib,
-    with each cell representing an NFAA label. The color of each cell is determined by the corresponding label's color.
-    The function returns the generated figure object.
-    """
     order = ["normal", "nsfw"]
     words = sorted(nsfwlist, key=lambda x: order.index(x))
     colors = {"normal": "mistyrose", "nsfw": "red"}
@@ -232,25 +199,21 @@ def generate_images_plots(prompt, model_name):
         except:
             skintones.append(None)
         genders.append(genderfromcaption(caption))
-        ages.append(age_detector(image)) # Call age_detector function
-        nsfws.append(is_nsfw(image)) # Call is_nsfw function
+        ages.append(age_detector(image))
+        nsfws.append(is_nsfw(image))
     return images, skintoneplot(skintones), genderplot(genders), ageplot(ages), nsfwplot(nsfws)
 
 with gr.Blocks(title="Demographic bias in Text-to-Image Generation Models") as demo:
     gr.Markdown("# Demographic bias in Text to Image Models")
     gr.Markdown('''
 In this demo, we explore the potential biases in text-to-image models by generating multiple images based on user prompts and analyzing the gender, skin tone, age, and potential sexual nature of the generated subjects. Here's how the analysis works:
-
 1. **Image Generation**: For each prompt, 10 images are generated using the selected model.
 2. **Gender Detection**: The [BLIP caption generator](https://huggingface.co/Salesforce/blip-image-captioning-large) is used to elicit gender markers by identifying words like "man," "boy," "woman," and "girl" in the captions.
 3. **Skin Tone Classification**: The [skin-tone-classifier library](https://github.com/ChenglongMa/SkinToneClassifier) is used to extract the skin tones of the generated subjects.
 4. **Age Detection**: The [Faces Age Detection model](https://huggingface.co/dima806/faces_age_detection) is used to identify the age of the generated subjects.
 5. **NFAA Detection**: The [Falconsai/nsfw_image_detection](https://huggingface.co/Falconsai/nsfw_image_detection) model is used to identify whether the generated images are NFAA (not for all audiences).
-
 #### Visualization
-
 We create visual grids to represent the data:
-
 - **Skin Tone Grids**: Skin tones are plotted as exact hex codes rather than using the Fitzpatrick scale, which can be [problematic and limiting for darker skin tones](https://arxiv.org/pdf/2309.05148).
 - **Gender Grids**: Light green denotes men, dark green denotes women, and grey denotes cases where the BLIP caption did not specify a binary gender.
 - **Age Grids**: Light blue denotes people between 18 and 30, blue denotes people between 30 and 50, and dark blue denotes people older than 50.
@@ -266,7 +229,8 @@ This demo provides an insightful look into how current text-to-image models hand
             "stabilityai/sdxl-turbo",
             "ByteDance/SDXL-Lightning",
             "stabilityai/stable-diffusion-2",
-            "segmind/SSD-1B"
+            "segmind/SSD-1B",
+            "black-forest-labs/FLUX.1-dev"
         ],
         value=default_model
     )
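
For reference, the FLUX.1-dev path added in this commit can be exercised on its own. The sketch below is a minimal standalone version of that branch, assuming a diffusers release that ships FluxPipeline and that the gated black-forest-labs/FLUX.1-dev checkpoint is accessible after Hub login; the prompt string and output filename are illustrative placeholders, while the dtype, CPU-offload call, and sampler settings mirror those in the diff.

import torch
from diffusers import FluxPipeline

# Load FLUX.1-dev in bfloat16 and let diffusers offload submodules to CPU between
# forward passes, as in the new load_model() branch.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()

# Generate one 512x512 image with the same settings used in getimgen().
# The prompt and filename below are placeholders for illustration only.
image = pipe(
    "a portrait photo of a firefighter",
    height=512,
    width=512,
    guidance_scale=3.5,
    num_inference_steps=50,
    max_sequence_length=512,
    generator=torch.Generator("cpu").manual_seed(0),
).images[0]
image.save("flux_sample.png")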