czl committed on
Commit 38d4034
1 Parent(s): bd14b9f

simplify UI

.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.JPEG filter=lfs diff=lfs merge=lfs -text
+*.jpeg filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -9,4 +9,50 @@ app_file: app.py
 pinned: false
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# Generative Data Augmentation Demo
+
+Main GitHub Repo: [Generative Data Augmentation](https://github.com/zhulinchng/generative-data-augmentation) | Image Classification Demo: [Generative Augmented Classifiers](https://huggingface.co/spaces/czl/generative-augmented-classifiers).
+
+This demo is created as part of the 'Investigating the Effectiveness of Generative Diffusion Models in Synthesizing Images for Data Augmentation in Image Classification' dissertation.
+
+The user can augment an image by interpolating between two prompts, specifying the total number of interpolation steps and the specific step at which to generate the image.
+
+## Demo Usage Instructions
+
+1. Upload an image.
+2. Enter the two prompts to interpolate between: the first prompt should contain the desired class of the augmented image, and the second prompt should contain the undesired (i.e., confusing) class.
+
+## Configuration
+
+- Total Interpolation Steps: The number of steps to interpolate between the two prompts.
+- Interpolation Step: The specific step at which to generate the image.
+- Example for 10 steps:
+
+```python
+Total: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+Sampled: 4
+```
+
+- Seed: Seed value for reproducibility.
+- Negative Prompt: Prompt describing what the model should avoid generating.
+- Width, Height: The dimensions of the generated image.
+- Guidance Scale: Controls how closely the model follows the prompts.
+
+## Metadata
+
+[SSIM Score](https://lightning.ai/docs/torchmetrics/stable/image/structural_similarity.html): Structural Similarity Index (SSIM) score between the original and generated image, ranging from 0 to 1.
+[CLIP Score](https://lightning.ai/docs/torchmetrics/stable/multimodal/clip_score.html): CLIP similarity score between the original and generated image, ranging from 0 to 100.
+
+## Local Setup
+
+```bash
+git clone https://huggingface.co/spaces/czl/generative-data-augmentation-demo
+cd generative-data-augmentation-demo
+# Set up the data directory structure as shown above
+conda create --name $env_name python=3.11.*  # Replace $env_name with your environment name
+conda activate $env_name
+# Visit the PyTorch website https://pytorch.org/get-started/previous-versions/#v212 for installation instructions.
+pip install torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --index-url <url>  # Obtain the correct URL from the PyTorch website
+pip install -r requirements.txt
+python app.py
+```
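
The configuration above amounts to choosing one point along an evenly spaced interpolation between the embeddings of the two prompts. The demo delegates this to `synth.interpolatePrompts` in `tools/synth.py`; the following is only a rough sketch of the idea, with the CLIP text encoder and Stable Diffusion model id as assumptions rather than values read from this repository:

```python
# Hypothetical sketch of prompt-embedding interpolation; the demo's actual logic
# lives in tools/synth.interpolatePrompts and may differ (e.g. spherical interpolation).
import torch
from transformers import CLIPTextModel, CLIPTokenizer

model_id = "runwayml/stable-diffusion-v1-5"  # assumed base model
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")


def embed(prompt: str) -> torch.Tensor:
    """Encode a prompt into CLIP text embeddings of shape (1, 77, 768)."""
    tokens = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        return text_encoder(tokens.input_ids)[0]


actual = embed("A photo of a Shih-Tzu, a type of dog")    # Prompt 1: desired class
confusing = embed("A photo of a beagle, a type of dog")   # Prompt 2: confusing class

num_interpolation_steps = 10  # "Total" in the example above
interpolation_step = 4        # "Sampled" in the example above (1-indexed)

# Evenly spaced mixing weights from 0 (prompt 1) to 1 (prompt 2).
weights = torch.linspace(0, 1, num_interpolation_steps)
interpolated = [torch.lerp(actual, confusing, w) for w in weights]
prompt_embeds = interpolated[interpolation_step - 1]  # offset step by -1, as app.py does
```

With `Total = 10` and `Sampled = 4`, the selected embedding lies 3/9 of the way from the actual-class prompt toward the confusing-class prompt.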
app.py CHANGED
@@ -4,6 +4,7 @@ import gradio as gr
 import numpy as np
 import torch
 import torchvision.transforms as transforms
+from PIL import Image
 from torchmetrics.functional.image import structural_similarity_index_measure as ssim
 from transformers import CLIPModel, CLIPProcessor
 
@@ -45,8 +46,6 @@ def infer(
     interpolation_step,
     num_inference_steps,
     num_interpolation_steps,
-    sample_mid_interpolation,
-    remove_n_middle,
 ):
     device = "cuda" if torch.cuda.is_available() else "cpu"
 
@@ -55,38 +54,15 @@
         assert num_interpolation_steps % 2 == 0
     except AssertionError:
         raise ValueError("num_interpolation_steps must be an even number")
-    try:
-        assert sample_mid_interpolation % 2 == 0
-    except AssertionError:
-        raise ValueError("sample_mid_interpolation must be an even number")
-    try:
-        assert remove_n_middle % 2 == 0
-    except AssertionError:
-        raise ValueError("remove_n_middle must be an even number")
-    try:
-        assert num_interpolation_steps >= sample_mid_interpolation
-    except AssertionError:
-        raise ValueError(
-            "num_interpolation_steps must be greater than or equal to sample_mid_interpolation"
-        )
-    try:
-        assert num_interpolation_steps >= 2 and sample_mid_interpolation >= 2
-    except AssertionError:
-        raise ValueError(
-            "num_interpolation_steps and sample_mid_interpolation must be greater than or equal to 2"
-        )
-    try:
-        assert sample_mid_interpolation - remove_n_middle >= 2
-    except AssertionError:
-        raise ValueError(
-            "sample_mid_interpolation must be greater than or equal to remove_n_middle + 2"
-        )
 
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
     prompts = [prompt1, prompt2]
     generator = torch.Generator().manual_seed(seed)
 
+    sample_mid_interpolation = num_interpolation_steps
+    remove_n_middle = 0
+
     interpolated_prompt_embeds, prompt_metadata = synth.interpolatePrompts(
         prompts,
         pipe,
@@ -116,7 +92,6 @@ def infer(
     ).to(device)
     embed_pairs = zip(interpolated_prompt_embeds, interpolated_negative_prompts_embeds)
     embed_pairs_list = list(embed_pairs)
-    print(len(embed_pairs_list))
     # offset step by -1
     prompt_embeds, negative_prompt_embeds = embed_pairs_list[interpolation_step - 1]
     preprocess_input = transforms.Compose(
@@ -127,7 +102,7 @@ def infer(
         npe = negative_prompt_embeds[None, ...]
     else:
         npe = None
-    image = pipe(
+    images_list = pipe(
         height=height,
         width=width,
         num_images_per_prompt=1,
@@ -138,7 +113,13 @@ def infer(
         generator=generator,
         latents=latents,
         image=input_img_tensor,
-    ).images[0]
+    )
+    if images_list["nsfw_content_detected"][0]:
+        image = Image.open("samples/unsafe.jpeg")
+        return image, seed, "Unsafe content detected", "Unsafe content detected"
+    else:
+        image = images_list.images[0]
+
     pred_image = transforms.ToTensor()(image).unsqueeze(0)
     ssim_score = ssim(pred_image, input_img_tensor).item()
     real_inputs = clip_processor(
@@ -163,25 +144,17 @@ def infer(
 
 
 examples1 = [
-    "A photo of a chain saw, chainsaw",
+    "A photo of a garbage truck, dustcart",
     "A photo of a Shih-Tzu, a type of dog",
 ]
 examples2 = [
-    "A photo of a golf ball",
+    "A photo of a cassette player",
     "A photo of a beagle, a type of dog",
 ]
 
 
 def update_steps(total_steps, interpolation_step):
-    if interpolation_step > total_steps:
-        return gr.update(maximum=total_steps // 2, value=total_steps)
-    return gr.update(maximum=total_steps // 2)
-
-
-def update_sampling_steps(total_steps, sample_steps):
-    # if sample_steps > total_steps:
-    #     return gr.update(value=total_steps)
-    return gr.update(value=total_steps)
+    return gr.update(maximum=total_steps)
 
 
 def update_format(image_format):
@@ -211,7 +184,7 @@ with gr.Blocks(title="Generative Date Augmentation Demo") as demo:
             label="Prompt for the image to synthesize. (Actual class)",
             show_label=True,
             max_lines=1,
-            placeholder="Enter your first prompt",
+            placeholder="Enter Prompt for the image to synthesize. (Actual class)",
             container=False,
         )
     with gr.Row():
@@ -219,32 +192,44 @@ with gr.Blocks(title="Generative Date Augmentation Demo") as demo:
             label="Prompt to augment against. (Confusing class)",
             show_label=True,
             max_lines=1,
-            placeholder="Enter your second prompt",
+            placeholder="Enter Prompt to augment against. (Confusing class)",
             container=False,
         )
     with gr.Row():
         gr.Examples(
-            examples=examples1, inputs=[prompt1], label="Example for Prompt 1"
+            examples=[
+                "samples/n03417042_5234.JPEG",
+                "samples/n02086240_2799.JPEG",
+            ],
+            inputs=[input_image],
+            label="Example Images",
+        )
+        gr.Examples(
+            examples=examples1,
+            inputs=[prompt1],
+            label="Example for Prompt 1 (Actual class)",
         )
         gr.Examples(
-            examples=examples2, inputs=[prompt2], label="Example for Prompt 2"
+            examples=examples2,
+            inputs=[prompt2],
+            label="Example for Prompt 2 (Confusing class)",
         )
 
     with gr.Row():
-        interpolation_step = gr.Slider(
-            label="Specific Interpolation Step",
-            minimum=1,
-            maximum=8,
-            step=1,
-            value=8,
-        )
         num_interpolation_steps = gr.Slider(
-            label="Total interpolation steps",
+            label="Total Interpolation Steps",
             minimum=2,
-            maximum=32,
+            maximum=128,
             step=2,
             value=16,
         )
+        interpolation_step = gr.Slider(
+            label="Sample Interpolation Step",
+            minimum=1,
+            maximum=16,
+            step=1,
+            value=8,
+        )
         num_interpolation_steps.change(
            fn=update_steps,
            inputs=[num_interpolation_steps, interpolation_step],
@@ -305,27 +290,6 @@ with gr.Blocks(title="Generative Date Augmentation Demo") as demo:
             step=1,
             value=25,
         )
-    with gr.Row():
-        sample_mid_interpolation = gr.Slider(
-            label="Number of sampling steps in the middle of interpolation",
-            minimum=2,
-            maximum=80,
-            step=2,
-            value=16,
-        )
-        num_interpolation_steps.change(
-            fn=update_sampling_steps,
-            inputs=[num_interpolation_steps, sample_mid_interpolation],
-            outputs=[sample_mid_interpolation],
-        )
-    with gr.Row():
-        remove_n_middle = gr.Slider(
-            label="Number of middle steps to remove from interpolation",
-            minimum=0,
-            maximum=80,
-            step=2,
-            value=0,
-        )
     with gr.Row():
         image_type = gr.Radio(
             choices=[
@@ -372,6 +336,10 @@ Note: Running on CPU will take longer (approx. 6 minutes with default settings).
 This demo is created as part of the 'Investigating the Effectiveness of Generative Diffusion Models in Synthesizing Images for Data Augmentation in Image Classification' dissertation.
 
 The user can augment an image by interpolating between two prompts, and specify the number of interpolation steps and the specific step to generate the image.
+
+View the files used in this demo [here](https://huggingface.co/spaces/czl/generative-data-augmentation-demo/tree/main).
+
+Note: Safety checker is enabled to prevent unsafe content from being displayed in this public demo.
 """
     )
     run_button.click(
@@ -389,27 +357,8 @@ The user can augment an image by interpolating between two prompts, and specify
             interpolation_step,
             num_inference_steps,
             num_interpolation_steps,
-            sample_mid_interpolation,
-            remove_n_middle,
         ],
         outputs=[result, show_seed, ssim_score, cos_sim],
     )
 
 demo.queue().launch(show_error=True)
-
-"""
-    input_image,
-    prompt1,
-    prompt2,
-    negative_prompt,
-    seed,
-    randomize_seed,
-    width,
-    height,
-    guidance_scale,
-    interpolation_step,
-    num_inference_steps,
-    num_interpolation_steps,
-    sample_mid_interpolation,
-    remove_n_middle,
-"""
 
samples/n02086240_2799.JPEG ADDED

Git LFS Details

  • SHA256: ea7a31e982240ada8d2a45827be298f68c72bbef4a95f692653cca70d9d12c20
  • Pointer size: 131 Bytes
  • Size of remote file: 166 kB
samples/n03417042_5234.JPEG ADDED

Git LFS Details

  • SHA256: 7133394239f22a329d0c4ae568cc0aad4992e5a567a272e84257c2b0356e0fbd
  • Pointer size: 131 Bytes
  • Size of remote file: 153 kB
samples/unsafe.jpeg ADDED

Git LFS Details

  • SHA256: 35e213ae5ad722e6c7bf6ee24c491a8a34ee0148d329b76b6e452871413917c4
  • Pointer size: 130 Bytes
  • Size of remote file: 41.6 kB
tools/synth.py CHANGED
@@ -157,7 +157,6 @@ def pipe_img(
         scheduler=scheduler,
         torch_dtype=torch.float32,
         use_safetensors=use_safetensors,
-        safety_checker=None,
     ).to(device)
     if cpu_offload:
         pipe.enable_model_cpu_offload()
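
Removing `safety_checker=None` here means the pipeline is constructed with its default safety checker, which is what lets `app.py` (above) branch on `nsfw_content_detected` and show `samples/unsafe.jpeg` instead of a flagged image. A minimal sketch of that behaviour with `diffusers`; the model id, strength, and prompt below are assumptions, not values from this repository:

```python
# Sketch of the safety-checker flow this commit enables; values below are assumptions.
import torch
from diffusers import StableDiffusionImg2ImgPipeline
from PIL import Image

pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # assumed base model
    torch_dtype=torch.float32,
    # No safety_checker=None here, so the default safety checker stays loaded.
)

init_image = Image.open("samples/n02086240_2799.JPEG").convert("RGB").resize((512, 512))
result = pipe(
    prompt="A photo of a Shih-Tzu, a type of dog",
    image=init_image,
    strength=0.75,
    guidance_scale=7.5,
    generator=torch.Generator().manual_seed(0),
)

# With the checker active, the output carries a per-image NSFW flag.
if result.nsfw_content_detected and result.nsfw_content_detected[0]:
    image = Image.open("samples/unsafe.jpeg")  # placeholder added by this commit
else:
    image = result.images[0]
```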
 