fffiloni committed on
Commit a4737a3
1 Parent(s): 097203a

Upload 12 files

LICENSE ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 Ashleigh Watson and Alex Nasa

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
app.py ADDED
@@ -0,0 +1,563 @@
import warnings
warnings.filterwarnings("ignore")
from diffusers import StableDiffusionPipeline, DDIMInverseScheduler, DDIMScheduler
import torch
from typing import Optional
from tqdm import tqdm
from diffusers.models.attention_processor import Attention, AttnProcessor2_0
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import gc
import gradio as gr
import numpy as np
import os
import pickle
from transformers import CLIPImageProcessor
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
import argparse

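# Per-resolution self-attention blend weights used by AttnReplaceProcessor below.
# Keys are the flattened spatial size of each UNet attention block (64x64 -> 4096, 32x32 -> 1024, ...);
# 1.0 keeps the source image's attention at that resolution, 0.0 uses the target prompt's attention.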
weights = {
    'down': {
        4096: 0.0,
        1024: 1.0,
        256: 1.0,
    },
    'mid': {
        64: 1.0,
    },
    'up': {
        256: 1.0,
        1024: 1.0,
        4096: 0.0,
    }
}
num_inference_steps = 10
model_id = "stabilityai/stable-diffusion-2-1-base"

pipe = StableDiffusionPipeline.from_pretrained(model_id).to("cuda")
inverse_scheduler = DDIMInverseScheduler.from_pretrained(model_id, subfolder="scheduler")
scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")

safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker").to("cuda")
feature_extractor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")

should_stop = False

def save_state_to_file(state):
    filename = "state.pkl"
    with open(filename, 'wb') as f:
        pickle.dump(state, f)
    return filename

def load_state_from_file(filename):
    with open(filename, 'rb') as f:
        state = pickle.load(f)
    return state

def stop_reconstruct():
    global should_stop
    should_stop = True

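# Reconstruction pass: DDIM-invert the uploaded image under its caption, then optimise a
# per-step embedding QT so that denoising with classifier-free guidance retraces the
# inversion trajectory. Yields intermediate previews plus the state reused by apply_prompt().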
def reconstruct(input_img, caption):

    img = input_img

    cond_prompt_embeds = pipe.encode_prompt(prompt=caption, device="cuda", num_images_per_prompt=1, do_classifier_free_guidance=False)[0]
    uncond_prompt_embeds = pipe.encode_prompt(prompt="", device="cuda", num_images_per_prompt=1, do_classifier_free_guidance=False)[0]

    prompt_embeds_combined = torch.cat([uncond_prompt_embeds, cond_prompt_embeds])

    transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize((512, 512)),
        torchvision.transforms.ToTensor()
    ])

    loaded_image = transform(img).to("cuda").unsqueeze(0)

    if loaded_image.shape[1] == 4:
        loaded_image = loaded_image[:,:3,:,:]

    with torch.no_grad():
        encoded_image = pipe.vae.encode(loaded_image*2 - 1)
        real_image_latents = pipe.vae.config.scaling_factor * encoded_image.latent_dist.sample()

    guidance_scale = 1
    inverse_scheduler.set_timesteps(num_inference_steps, device="cuda")
    timesteps = inverse_scheduler.timesteps

    latents = real_image_latents

    inversed_latents = []

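    # DDIM inversion: run the scheduler in reverse and keep every intermediate latent so the
    # optimisation loop below has a per-step target to match.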
    with torch.no_grad():

        replace_attention_processor(pipe.unet, True)

        for i, t in tqdm(enumerate(timesteps), total=len(timesteps), desc="Inference steps"):

            inversed_latents.append(latents)

            latent_model_input = torch.cat([latents] * 2)

            noise_pred = pipe.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds_combined,
                cross_attention_kwargs=None,
                return_dict=False,
            )[0]

            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            latents = inverse_scheduler.step(noise_pred, t, latents, return_dict=False)[0]

    # initial state
    real_image_initial_latents = latents

    W_values = uncond_prompt_embeds.repeat(num_inference_steps, 1, 1)
    QT = nn.Parameter(W_values.clone())

    guidance_scale = 7.5
    scheduler.set_timesteps(num_inference_steps, device="cuda")
    timesteps = scheduler.timesteps

    optimizer = torch.optim.AdamW([QT], lr=0.008)

    pipe.vae.eval()
    pipe.vae.requires_grad_(False)
    pipe.unet.eval()
    pipe.unet.requires_grad_(False)

    last_loss = 1

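    # Optimise QT for up to 50 epochs: each epoch re-runs the denoising trajectory and nudges the
    # per-step embeddings towards the stored inversion latents, lowering the learning rate as the
    # loss shrinks and stopping early once it falls below 0.02 (or when the Stop button is pressed).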
    for epoch in range(50):
        gc.collect()
        torch.cuda.empty_cache()

        if last_loss < 0.02:
            break
        elif last_loss < 0.03:
            for param_group in optimizer.param_groups:
                param_group['lr'] = 0.003
        elif last_loss < 0.035:
            for param_group in optimizer.param_groups:
                param_group['lr'] = 0.006

        intermediate_values = real_image_initial_latents.clone()

        for i in range(num_inference_steps):
            latents = intermediate_values.detach().clone()

            t = timesteps[i]

            prompt_embeds = torch.cat([QT[i].unsqueeze(0), cond_prompt_embeds.detach()])

            latent_model_input = torch.cat([latents] * 2)

            noise_pred_model = pipe.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=None,
                return_dict=False,
            )[0]

            noise_pred_uncond, noise_pred_text = noise_pred_model.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            intermediate_values = scheduler.step(noise_pred, t, latents, return_dict=False)[0]

            loss = F.mse_loss(inversed_latents[len(timesteps) - 1 - i].detach(), intermediate_values, reduction="mean")
            last_loss = loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        global should_stop
        if should_stop:
            should_stop = False
            break

        image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
        image = (image / 2.0 + 0.5).clamp(0.0, 1.0)
        safety_checker_input = feature_extractor(image, return_tensors="pt", do_rescale=False).to("cuda")
        image = safety_checker(images=[image], clip_input=safety_checker_input.pixel_values.to("cuda"))[0]
        image_np = image[0].squeeze(0).float().permute(1,2,0).detach().cpu().numpy()
        image_np = (image_np * 255).astype(np.uint8)

        yield image_np, caption, [caption, real_image_initial_latents, QT]

    image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor, return_dict=False)[0]
    image = (image / 2.0 + 0.5).clamp(0.0, 1.0)
    safety_checker_input = feature_extractor(image, return_tensors="pt", do_rescale=False).to("cuda")
    image = safety_checker(images=[image], clip_input=safety_checker_input.pixel_values.to("cuda"))[0]
    image_np = image[0].squeeze(0).float().permute(1,2,0).detach().cpu().numpy()
    image_np = (image_np * 255).astype(np.uint8)

    yield image_np, caption, [caption, real_image_initial_latents, QT]

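# Attention processor that can blend the self-attention scores of the source (reconstruction)
# half of the batch into the target (new prompt) half, weighted per resolution via `weight`.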
class AttnReplaceProcessor(AttnProcessor2_0):

    def __init__(self, replace_all, weight):
        super().__init__()
        self.replace_all = replace_all
        self.weight = weight

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        *args,
        **kwargs,
    ) -> torch.FloatTensor:

        residual = hidden_states

        is_cross = encoder_hidden_states is not None

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, _, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_scores = attn.scale * torch.bmm(query, key.transpose(-1, -2))

        dimension_squared = hidden_states.shape[1]

        if not is_cross and self.replace_all:
            ucond_attn_scores_src, ucond_attn_scores_dst, attn_scores_src, attn_scores_dst = attention_scores.chunk(4)
            attn_scores_dst.copy_(self.weight[dimension_squared] * attn_scores_src + (1.0 - self.weight[dimension_squared]) * attn_scores_dst)
            ucond_attn_scores_dst.copy_(self.weight[dimension_squared] * ucond_attn_scores_src + (1.0 - self.weight[dimension_squared]) * ucond_attn_scores_dst)

        attention_probs = attention_scores.softmax(dim=-1)
        del attention_scores

        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)
        del attention_probs

        hidden_states = attn.to_out[0](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

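# Install AttnReplaceProcessor on every self-attention ('attn1') module of the UNet;
# clear=True installs it with blending disabled, restoring standard attention behaviour.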
def replace_attention_processor(unet, clear=False):

    for name, module in unet.named_modules():
        if 'attn1' in name and 'to' not in name:
            layer_type = name.split('.')[0].split('_')[0]

            if not clear:
                if layer_type == 'down':
                    module.processor = AttnReplaceProcessor(True, weights['down'])
                elif layer_type == 'mid':
                    module.processor = AttnReplaceProcessor(True, weights['mid'])
                elif layer_type == 'up':
                    module.processor = AttnReplaceProcessor(True, weights['up'])
            else:
                module.processor = AttnReplaceProcessor(False, 0.0)

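# Editing pass: starting from the stored inverted latents, denoise a source and a target
# trajectory side by side (both reusing the optimised QT embeddings as their unconditional
# input) while the attention processor injects source self-attention into the target,
# then decode only the target latent.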
def apply_prompt(meta_data, new_prompt):

    caption, real_image_initial_latents, QT = meta_data

    inference_steps = len(QT)

    cond_prompt_embeds = pipe.encode_prompt(prompt=caption, device="cuda", num_images_per_prompt=1, do_classifier_free_guidance=False)[0]
    # uncond_prompt_embeds = pipe.encode_prompt(prompt=caption, device="cuda", num_images_per_prompt=1, do_classifier_free_guidance=False)[0]
    new_prompt_embeds = pipe.encode_prompt(prompt=new_prompt, device="cuda", num_images_per_prompt=1, do_classifier_free_guidance=False)[0]

    guidance_scale = 7.5
    scheduler.set_timesteps(inference_steps, device="cuda")
    timesteps = scheduler.timesteps

    latents = torch.cat([real_image_initial_latents] * 2)

    with torch.no_grad():
        replace_attention_processor(pipe.unet)

        for i, t in tqdm(enumerate(timesteps), total=len(timesteps), desc="Inference steps"):

            modified_prompt_embeds = torch.cat([QT[i].unsqueeze(0), QT[i].unsqueeze(0), cond_prompt_embeds, new_prompt_embeds])
            latent_model_input = torch.cat([latents] * 2)

            noise_pred = pipe.unet(
                latent_model_input,
                t,
                encoder_hidden_states=modified_prompt_embeds,
                cross_attention_kwargs=None,
                return_dict=False,
            )[0]

            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]

        replace_attention_processor(pipe.unet, True)

        image = pipe.vae.decode(latents[1].unsqueeze(0) / pipe.vae.config.scaling_factor, return_dict=False)[0]
        image = (image / 2.0 + 0.5).clamp(0.0, 1.0)
        safety_checker_input = feature_extractor(image, return_tensors="pt", do_rescale=False).to("cuda")
        image = safety_checker(images=[image], clip_input=safety_checker_input.pixel_values.to("cuda"))[0]
        image_np = image[0].squeeze(0).float().permute(1,2,0).detach().cpu().numpy()
        image_np = (image_np * 255).astype(np.uint8)

    return image_np

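# Selecting one of the bundled examples loads its precomputed reconstruction state from
# assets/<name>.pkl and immediately applies the matching example prompt.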
def on_image_change(filepath):
    # Extract the filename without extension
    filename = os.path.splitext(os.path.basename(filepath))[0]

    # Only handle the bundled example images, which ship with precomputed state
    if filename in ["example1", "example2", "example3", "example4"]:
        meta_data_raw = load_state_from_file(f"assets/{filename}.pkl")
        _, _, QT_raw = meta_data_raw

        global num_inference_steps
        num_inference_steps = len(QT_raw)
        scale_value = 7
        new_prompt = ""

        if filename == "example1":
            scale_value = 7
            new_prompt = "a photo of a tree, summer, colourful"

        elif filename == "example2":
            scale_value = 8
            new_prompt = "a photo of a panda, two ears, white background"

        elif filename == "example3":
            scale_value = 7
            new_prompt = "a realistic photo of a female warrior, flowing dark purple or black hair, bronze shoulder armour, leather chest piece, sky background with clouds"

        elif filename == "example4":
            scale_value = 7
            new_prompt = "a photo of plastic bottle on some sand, beach background, sky background"

        update_scale(scale_value)
        img = apply_prompt(meta_data_raw, new_prompt)

        return filepath, img, meta_data_raw, num_inference_steps, scale_value, scale_value

def update_value(value, key, res):
    global weights
    weights[key][res] = value

def update_step(value):
    global num_inference_steps
    num_inference_steps = value

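# Map the 0-9 "Cross-Attention Influence" slider onto the seven per-resolution weights:
# starting from all 1.0, each step below 9 zeroes (or halves) the outermost resolutions first,
# working symmetrically inwards, then writes the result back into the global `weights` dict.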
def update_scale(scale):
    values = [1.0] * 7

    if scale == 9:
        return values

    reduction_steps = (9 - scale) * 0.5

    for i in range(4):  # There are 4 positions to reduce symmetrically
        if reduction_steps >= 1:
            values[i] = 0.0
            values[-(i + 1)] = 0.0
            reduction_steps -= 1
        elif reduction_steps > 0:
            values[i] = 0.5
            values[-(i + 1)] = 0.5
            break

    global weights
    index = 0

    for outer_key, inner_dict in weights.items():
        for inner_key in inner_dict:
            inner_dict[inner_key] = values[index]
            index += 1

    return weights['down'][4096], weights['down'][1024], weights['down'][256], weights['mid'][64], weights['up'][256], weights['up'][1024], weights['up'][4096]

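# Gradio UI: upload an image and caption it, press Reconstruct to build the editable state,
# then change the prompt and press Generate Vision to re-render with attention injection.
# The Advanced Options accordion exposes the per-resolution weights directly.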
with gr.Blocks() as demo:
    gr.Markdown(
        '''
        <div style="text-align: center;">
            <div style="display: flex; justify-content: center;">
                <img src="https://github.com/user-attachments/assets/55a38e74-ab93-4d80-91c8-0fa6130af45a" alt="Logo">
            </div>
            <h1>Out of Focus 1.0</h1>
            <p style="font-size:16px;">Out of AI presents a flexible tool for manipulating your images. This first version modifies an image through prompt manipulation, after reconstructing it with a diffusion inversion process.</p>
        </div>
        <br>
        <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
            <a href="https://www.buymeacoffee.com/outofai" target="_blank"><img src="https://img.shields.io/badge/-buy_me_a%C2%A0coffee-red?logo=buy-me-a-coffee" alt="Buy Me A Coffee"></a> &ensp;
            <a href="https://twitter.com/OutofAi" target="_blank"><img src="https://img.shields.io/twitter/url/https/twitter.com/cloudposse.svg?style=social&label=Ashleigh%20Watson"></a> &ensp;
            <a href="https://twitter.com/banterless_ai" target="_blank"><img src="https://img.shields.io/twitter/url/https/twitter.com/cloudposse.svg?style=social&label=Alex%20Nasa"></a>
        </div>
        '''
    )
    with gr.Row():
        with gr.Column():

            with gr.Row():
                example_input = gr.Image(height=512, width=512, type="filepath", visible=False)
                image_input = gr.Image(height=512, width=512, type="pil", label="Upload Source Image")
            steps_slider = gr.Slider(minimum=5, maximum=25, step=5, value=num_inference_steps, label="Steps", info="Number of inference steps required to reconstruct and modify the image")
            prompt_input = gr.Textbox(label="Prompt", info="Give a detailed initial prompt describing the image")
            reconstruct_button = gr.Button("Reconstruct")
            stop_button = gr.Button("Stop", variant="stop", interactive=False)
        with gr.Column():
            reconstructed_image = gr.Image(type="pil", label="Reconstructed")

    with gr.Row():
        invisible_slider = gr.Slider(minimum=0, maximum=9, step=1, value=7, visible=False)
        interpolate_slider = gr.Slider(minimum=0, maximum=9, step=1, value=7, label="Cross-Attention Influence", info="Scales how much influence the source image has on the target image")
    with gr.Row():
        new_prompt_input = gr.Textbox(label="New Prompt", interactive=False, info="Manipulate the image by changing the prompt or adding words at the end; the best results come from swapping words rather than adding or removing words in the middle")
    with gr.Row():
        apply_button = gr.Button("Generate Vision", variant="primary", interactive=False)
    with gr.Row():
        with gr.Accordion(label="Advanced Options", open=False):
            gr.Markdown(
                '''
                <div style="text-align: center;">
                    <h1>Weight Adjustment</h1>
                    <p style="font-size:16px;">Cross-Attention Influence weights can be modified manually for each resolution (1.0 = fully source attention, 0.0 = fully target attention)</p>
                </div>
                '''
            )
            down_slider_4096 = gr.Number(minimum=0.0, maximum=1.0, step=0.1, value=weights['down'][4096], label="Self-Attn Down 64x64")
            down_slider_1024 = gr.Number(minimum=0.0, maximum=1.0, step=0.1, value=weights['down'][1024], label="Self-Attn Down 32x32")
            down_slider_256 = gr.Number(minimum=0.0, maximum=1.0, step=0.1, value=weights['down'][256], label="Self-Attn Down 16x16")
            mid_slider_64 = gr.Number(minimum=0.0, maximum=1.0, step=0.1, value=weights['mid'][64], label="Self-Attn Mid 8x8")
            up_slider_256 = gr.Number(minimum=0.0, maximum=1.0, step=0.1, value=weights['up'][256], label="Self-Attn Up 16x16")
            up_slider_1024 = gr.Number(minimum=0.0, maximum=1.0, step=0.1, value=weights['up'][1024], label="Self-Attn Up 32x32")
            up_slider_4096 = gr.Number(minimum=0.0, maximum=1.0, step=0.1, value=weights['up'][4096], label="Self-Attn Up 64x64")

    with gr.Row():
        show_case = gr.Examples(
            examples=[
                ["assets/example4.png", "a photo of plastic bottle on a rock, mountain background, sky background", "a photo of plastic bottle on some sand, beach background, sky background"],
                ["assets/example1.png", "a photo of a tree, spring, foggy", "a photo of a tree, summer, colourful"],
                ["assets/example2.png", "a photo of a cat, two ears, white background", "a photo of a panda, two ears, white background"],
                ["assets/example3.png", "a digital illustration of a female warrior, flowing dark purple or black hair, bronze shoulder armour, leather chest piece, sky background with clouds", "a realistic photo of a female warrior, flowing dark purple or black hair, bronze shoulder armour, leather chest piece, sky background with clouds"],
            ],
            inputs=[example_input, prompt_input, new_prompt_input],
            label=None
        )

    meta_data = gr.State()

    example_input.change(
        fn=on_image_change,
        inputs=example_input,
        outputs=[image_input, reconstructed_image, meta_data, steps_slider, invisible_slider, interpolate_slider]
    ).then(
        lambda: gr.update(interactive=True),
        outputs=apply_button
    ).then(
        lambda: gr.update(interactive=True),
        outputs=new_prompt_input
    )
    steps_slider.release(update_step, inputs=steps_slider)
    interpolate_slider.release(update_scale, inputs=interpolate_slider, outputs=[down_slider_4096, down_slider_1024, down_slider_256, mid_slider_64, up_slider_256, up_slider_1024, up_slider_4096])
    invisible_slider.change(update_scale, inputs=invisible_slider, outputs=[down_slider_4096, down_slider_1024, down_slider_256, mid_slider_64, up_slider_256, up_slider_1024, up_slider_4096])

    up_slider_4096.change(update_value, inputs=[up_slider_4096, gr.State('up'), gr.State(4096)])
    up_slider_1024.change(update_value, inputs=[up_slider_1024, gr.State('up'), gr.State(1024)])
    up_slider_256.change(update_value, inputs=[up_slider_256, gr.State('up'), gr.State(256)])

    down_slider_4096.change(update_value, inputs=[down_slider_4096, gr.State('down'), gr.State(4096)])
    down_slider_1024.change(update_value, inputs=[down_slider_1024, gr.State('down'), gr.State(1024)])
    down_slider_256.change(update_value, inputs=[down_slider_256, gr.State('down'), gr.State(256)])

    mid_slider_64.change(update_value, inputs=[mid_slider_64, gr.State('mid'), gr.State(64)])

    reconstruct_button.click(reconstruct, inputs=[image_input, prompt_input], outputs=[reconstructed_image, new_prompt_input, meta_data]).then(
        lambda: gr.update(interactive=True),
        outputs=reconstruct_button
    ).then(
        lambda: gr.update(interactive=True),
        outputs=new_prompt_input
    ).then(
        lambda: gr.update(interactive=True),
        outputs=apply_button
    ).then(
        lambda: gr.update(interactive=False),
        outputs=stop_button
    )

    reconstruct_button.click(
        lambda: gr.update(interactive=False),
        outputs=reconstruct_button
    )

    reconstruct_button.click(
        lambda: gr.update(interactive=True),
        outputs=stop_button
    )

    reconstruct_button.click(
        lambda: gr.update(interactive=False),
        outputs=apply_button
    )

    stop_button.click(
        lambda: gr.update(interactive=False),
        outputs=stop_button
    )

    apply_button.click(apply_prompt, inputs=[meta_data, new_prompt_input], outputs=reconstructed_image)
    stop_button.click(stop_reconstruct)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--share", action="store_true")
    args = parser.parse_args()
    demo.queue()
    demo.launch(share=args.share)
assets/example1.pkl ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cd481563fee5830919786d31895653b35b44a486beb11881fd13cf98e213c184
size 3220274
assets/example1.png ADDED
assets/example2.pkl ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c2c26bd70e19685eb33b6514a5f26da4c2d3d69e306f60fba021beb390e86f36
size 3220286
assets/example2.png ADDED
assets/example3.pkl ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e539a7ce84d036519fdef1dc6610c1de32cf70540ef96375915e457a74d8f25d
size 3220392
assets/example3.png ADDED
assets/example4.pkl ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d91c3c54c5d987ea15365e3dc79e2df203751d80f1548020881de0e024d8ad9d
size 3220316
assets/example4.png ADDED
assets/logo.png ADDED
requirements.txt ADDED
@@ -0,0 +1,9 @@
diffusers
transformers
gradio
accelerate

--extra-index-url https://download.pytorch.org/whl/cu121
torch
torchvision
torchaudio