czl committed on
Commit 98b6f69
1 Parent(s): 9827b70

added prompt interpolation demo

The demo now takes an input image and two prompts: synth.interpolatePrompts interpolates between the two prompt embeddings, the pipeline generates an image-to-image result at a user-selected interpolation step, and the app reports the seed plus SSIM and CLIP cosine-similarity scores of the result against the input image.

Files changed (1)
  1. app.py +202 -31
app.py CHANGED
@@ -3,11 +3,16 @@ import random
 import gradio as gr
 import numpy as np
 import torch
+import torchvision.transforms as transforms
+from torchmetrics.functional.image import structural_similarity_index_measure as ssim
+from transformers import CLIPModel, CLIPProcessor
 
 from tools import synth
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 model_path = "runwayml/stable-diffusion-v1-5"
+clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
+clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
 
 if torch.cuda.is_available():
     torch.cuda.max_memory_allocated(device=device)
@@ -29,79 +34,192 @@ MAX_IMAGE_SIZE = 1024
 
 def infer(
     input_image,
-    prompt,
+    prompt1,
+    prompt2,
     negative_prompt,
     seed,
     randomize_seed,
     width,
     height,
     guidance_scale,
+    interpolation_step,
     num_inference_steps,
+    num_interpolation_steps,
+    sample_mid_interpolation,
+    remove_n_middle,
 ):
+    device = "cuda" if torch.cuda.is_available() else "cpu"
 
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
-
+    prompts = [prompt1, prompt2]
     generator = torch.Generator().manual_seed(seed)
+    print(seed)
+    interpolated_prompt_embeds, prompt_metadata = synth.interpolatePrompts(
+        prompts,
+        pipe,
+        num_interpolation_steps,
+        sample_mid_interpolation,
+        remove_n_middle=remove_n_middle,
+        device=device,
+    )
+    negative_prompts = [negative_prompt, negative_prompt]
+    if negative_prompts != ["", ""]:
+        interpolated_negative_prompts_embeds, negative_prompt_metadata = (
+            synth.interpolatePrompts(
+                negative_prompts,
+                pipe,
+                num_interpolation_steps,
+                sample_mid_interpolation,
+                remove_n_middle=remove_n_middle,
+                device=device,
+            )
+        )
+    else:
+        interpolated_negative_prompts_embeds, negative_prompt_metadata = [None] * len(
+            interpolated_prompt_embeds
+        ), None
 
+    latents = torch.randn(
+        (1, pipe.unet.config.in_channels, height // 8, width // 8),
+        generator=generator,
+    ).to(device)
+    embed_pairs = zip(interpolated_prompt_embeds, interpolated_negative_prompts_embeds)
+    embed_pairs_list = list(embed_pairs)
+    print(len(embed_pairs_list))
+    # offset step by -1
+    prompt_embeds, negative_prompt_embeds = embed_pairs_list[interpolation_step - 1]
+    preprocess_input = transforms.Compose(
+        [transforms.ToTensor(), transforms.Resize((512, 512))]
+    )
+    input_img_tensor = preprocess_input(input_image).unsqueeze(0)
+    if negative_prompt_embeds is not None:
+        npe = negative_prompt_embeds[None, ...]
+    else:
+        npe = None
     image = pipe(
-        prompt=prompt,
-        negative_prompt=negative_prompt,
-        guidance_scale=guidance_scale,
-        num_inference_steps=num_inference_steps,
-        width=width,
         height=height,
+        width=width,
+        num_images_per_prompt=1,
+        prompt_embeds=prompt_embeds[None, ...],
+        negative_prompt_embeds=npe,
+        num_inference_steps=num_inference_steps,
+        guidance_scale=guidance_scale,
         generator=generator,
-        image=input_image,
+        latents=latents,
+        image=input_img_tensor,
     ).images[0]
+    pred_image = transforms.ToTensor()(image).unsqueeze(0)
+    ssim_score = ssim(pred_image, input_img_tensor).item()
+    real_inputs = clip_processor(
+        text=prompts, padding=True, images=input_image, return_tensors="pt"
+    ).to(device)
+    real_output = clip_model(**real_inputs)
+    synth_inputs = clip_processor(
+        text=prompts, padding=True, images=image, return_tensors="pt"
+    ).to(device)
+    synth_output = clip_model(**synth_inputs)
+    cos_sim = torch.nn.CosineSimilarity(dim=1)
+    cosine_sim = (
+        cos_sim(real_output.image_embeds, synth_output.image_embeds)
+        .detach()
+        .cpu()
+        .numpy()
+        .squeeze()
+        * 100
+    )
 
-    return image
+    return image, seed, ssim_score, cosine_sim
 
 
-examples = [
-    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
-    "An astronaut riding a green horse",
-    "A delicious ceviche cheesecake slice",
+examples1 = [
+    "A photo of a chain saw, chainsaw",
+    "A photo of a Shih-Tzu, a type of dog",
+]
+examples2 = [
+    "A photo of a golf ball",
+    "A photo of a beagle, a type of dog",
 ]
 
 css = """
 #col-container {
     margin: 0 auto;
-    max-width: 520px;
 }
 """
 
+
+def update_steps(total_steps, interpolation_step):
+    if interpolation_step > total_steps:
+        return gr.update(maximum=total_steps // 2, value=total_steps)
+    return gr.update(maximum=total_steps // 2)
+
+
 if torch.cuda.is_available():
     power_device = "GPU"
 else:
     power_device = "CPU"
 
-with gr.Blocks(css=css) as demo:
+with gr.Blocks(css=css, title="Generative Data Augmentation") as demo:
 
     with gr.Column(elem_id="col-container"):
         gr.Markdown(
             f"""
-        # Text-to-Image Gradio Template
+        # Data Augmentation with Image-to-Image Diffusion Models via Prompt Interpolation
         Currently running on {power_device}.
         """
         )
 
-        with gr.Row():
+        input_image = gr.Image(type="pil", label="Image to Augment")
 
-            prompt = gr.Text(
-                label="Prompt",
-                show_label=False,
+        with gr.Row():
+            prompt1 = gr.Text(
+                label="Prompt 1",
+                show_label=True,
                 max_lines=1,
-                placeholder="Enter your prompt",
+                placeholder="Enter your first prompt",
                 container=False,
             )
-        input_image = gr.Image(type="pil", label="Input Image")
+        with gr.Row():
+            prompt2 = gr.Text(
+                label="Prompt 2",
+                show_label=True,
+                max_lines=1,
+                placeholder="Enter your second prompt",
+                container=False,
+            )
+        with gr.Row():
+            gr.Examples(
+                examples=examples1, inputs=[prompt1], label="Example for Prompt 1"
+            )
+            gr.Examples(
+                examples=examples2, inputs=[prompt2], label="Example for Prompt 2"
+            )
 
+        with gr.Row():
+            num_interpolation_steps = gr.Slider(
+                label="Total interpolation steps",
+                minimum=2,
+                maximum=32,
+                step=2,
+                value=16,
+            )
+            interpolation_step = gr.Slider(
+                label="Specific Interpolation Step",
+                minimum=1,
+                maximum=8,
+                step=1,
+                value=8,
+            )
+        num_interpolation_steps.change(
+            fn=update_steps,
+            inputs=[num_interpolation_steps, interpolation_step],
+            outputs=[interpolation_step],
+        )
         run_button = gr.Button("Run", scale=0)
 
         result = gr.Image(label="Result", show_label=False)
 
-        with gr.Accordion("Advanced Settings", open=False):
+        with gr.Accordion("Advanced Settings", open=True):
 
             negative_prompt = gr.Text(
                 label="Negative prompt",
@@ -120,6 +238,15 @@ with gr.Blocks(css=css) as demo:
 
             randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
 
+            gr.Markdown("Negative Prompt: ")
+            with gr.Row():
+                negative_prompt = gr.Text(
+                    label="Negative Prompt",
+                    show_label=True,
+                    max_lines=1,
+                    value="blurry image, disfigured, deformed, distorted, cartoon, drawings",
+                    container=False,
+                )
             with gr.Row():
 
                 width = gr.Slider(
@@ -145,33 +272,77 @@ with gr.Blocks(css=css) as demo:
                 minimum=0.0,
                 maximum=10.0,
                 step=0.1,
-                value=0.0,
+                value=8.0,
             )
 
             num_inference_steps = gr.Slider(
                 label="Number of inference steps",
                 minimum=1,
-                maximum=12,
+                maximum=80,
                 step=1,
-                value=2,
+                value=25,
             )
-
-        gr.Examples(examples=examples, inputs=[prompt])
-
+            with gr.Row():
+                sample_mid_interpolation = gr.Slider(
+                    label="Number of sampling steps in the middle of interpolation",
+                    minimum=2,
+                    maximum=80,
+                    step=2,
+                    value=16,
+                )
+            with gr.Row():
+                remove_n_middle = gr.Slider(
+                    label="Number of middle steps to remove from interpolation",
+                    minimum=0,
+                    maximum=80,
+                    step=2,
+                    value=0,
+                )
+        gr.Markdown(
+            """
+            Metadata:
+            """
+        )
+        with gr.Row():
+            show_seed = gr.Label(label="Seed:", value="Randomized seed")
+            ssim_score = gr.Label(label="SSIM Score:", value="Generate to see score")
+            cos_sim = gr.Label(label="CLIP Score:", value="Generate to see score")
     run_button.click(
         fn=infer,
         inputs=[
             input_image,
-            prompt,
+            prompt1,
+            prompt2,
             negative_prompt,
             seed,
             randomize_seed,
             width,
             height,
             guidance_scale,
+            interpolation_step,
             num_inference_steps,
+            num_interpolation_steps,
+            sample_mid_interpolation,
+            remove_n_middle,
         ],
-        outputs=[result],
+        outputs=[result, show_seed, ssim_score, cos_sim],
     )
 
 demo.queue().launch()
+
+"""
+    input_image,
+    prompt1,
+    prompt2,
+    negative_prompt,
+    seed,
+    randomize_seed,
+    width,
+    height,
+    guidance_scale,
+    interpolation_step,
+    num_inference_steps,
+    num_interpolation_steps,
+    sample_mid_interpolation,
+    remove_n_middle,
+"""