Upload 3 files
Browse files- app.py +22 -0
- magic_mix.py +202 -0
- requirements.txt +8 -0
app.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from magic_mix import magic_mix
|
3 |
+
|
4 |
+
# Gradio front end for `magic_mix`. The inputs are listed in the same order as
# the function's positional parameters:
# img, prompt, kmin, kmax, v, seed, steps, guidance_scale.
iface = gr.Interface(
    description="Implementation of MagicMix: Semantic Mixing with Diffusion Models paper",
    article="<p style='text-align: center'><a href='https://github.com/daspartho/MagicMix' target='_blank'>Github</a></p>",
    fn=magic_mix,
    inputs=[
        gr.Image(shape=(512, 512), type="pil"),
        gr.Text(),
        # kmin
        gr.Slider(value=0.3, minimum=0, maximum=1, step=0.1),
        # kmax: 0 is excluded because it makes `tmax == steps` inside
        # `magic_mix`, which indexes past the end of the timestep schedule.
        # NOTE(review): the default here (0.5) differs from the 0.6 default of
        # `magic_mix` and of the command-line interface; confirm which one is
        # intended.
        gr.Slider(value=0.5, minimum=0.1, maximum=1, step=0.1),
        # v (interpolation constant)
        gr.Slider(value=0.5, minimum=0, maximum=1, step=0.1),
        # seed
        gr.Number(value=42, maximum=2**64-1),
        # steps: explicit whole-number range starting at 1, so that zero
        # sampling steps can no longer be selected.
        gr.Slider(value=50, minimum=1, maximum=100, step=1),
        # guidance_scale
        gr.Slider(value=7.5, minimum=1, maximum=15, step=0.1),
    ],
    outputs=gr.Image(),
    title="MagicMix",
)

iface.launch()
|
magic_mix.py
ADDED
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler
|
2 |
+
from transformers import CLIPTextModel, CLIPTokenizer, logging
|
3 |
+
import torch
|
4 |
+
from torchvision import transforms as tfms
|
5 |
+
from tqdm.auto import tqdm
|
6 |
+
from PIL import Image
|
7 |
+
|
8 |
+
# Suppress some unnecessary warnings when loading the CLIPTextModel.
logging.set_verbosity_error()

# Use the GPU when torch reports one, otherwise stay on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Identifiers handed to `from_pretrained` for the components loaded below.
_CLIP_REPO = "openai/clip-vit-large-patch14"
_SD_REPO = "CompVis/stable-diffusion-v1-4"

# Text side: tokenizer plus text encoder.
tokenizer = CLIPTokenizer.from_pretrained(_CLIP_REPO)
text_encoder = CLIPTextModel.from_pretrained(_CLIP_REPO).to(device)

# Image side: autoencoder and conditional UNet, both moved to `device`.
vae = AutoencoderKL.from_pretrained(_SD_REPO, subfolder="vae").to(device)
unet = UNet2DConditionModel.from_pretrained(_SD_REPO, subfolder="unet").to(device)

# DDIM scheduler configured with a scaled-linear beta schedule.
beta_start, beta_end = 0.00085, 0.012
scheduler = DDIMScheduler(
    beta_start=beta_start,
    beta_end=beta_end,
    beta_schedule="scaled_linear",
    num_train_timesteps=1000,
    clip_sample=False,
    set_alpha_to_one=False,
)
|
43 |
+
|
44 |
+
|
45 |
+
# convert PIL image to latents
|
46 |
+
def encode(img):
    """Encode an image into scaled VAE latents.

    Args:
        img: Image to encode. A PIL image in any mode other than RGB
            (e.g. RGBA, palette or grayscale) is converted to RGB first so
            that the tensor handed to the VAE has three channels.

    Returns:
        A sample drawn from the VAE's latent distribution, multiplied by the
        0.18215 latent scaling factor, with a leading batch dimension of 1.
    """
    # Objects without a `mode` attribute are passed through untouched, exactly
    # as before.
    if getattr(img, "mode", "RGB") != "RGB":
        img = img.convert("RGB")

    # ToTensor output is rescaled with `* 2 - 1` before being fed to the VAE.
    pixels = tfms.ToTensor()(img).unsqueeze(0).to(device) * 2 - 1

    with torch.no_grad():
        posterior = vae.encode(pixels)
        latent = 0.18215 * posterior.latent_dist.sample()
    return latent
|
51 |
+
|
52 |
+
|
53 |
+
# convert latents to PIL image
|
54 |
+
def decode(latent):
    """Decode scaled latents with the VAE and return a PIL image.

    The 0.18215 scaling applied in `encode` is undone before decoding. Only
    the first item of the batch is converted to an image.
    """
    unscaled = (1 / 0.18215) * latent
    with torch.no_grad():
        decoded = vae.decode(unscaled).sample

    # Shift/scale the decoder output and clamp it into [0, 1].
    pixels = (decoded / 2 + 0.5).clamp(0, 1)
    # Channels-last layout on the CPU, as a NumPy array.
    array = pixels.detach().cpu().permute(0, 2, 3, 1).numpy()
    array = (array * 255).round().astype("uint8")
    return Image.fromarray(array[0])
|
62 |
+
|
63 |
+
|
64 |
+
# convert prompt into text embeddings, also unconditional embeddings
|
65 |
+
def _embed_text(text):
    """Tokenize *text* and return the first output of the text encoder."""
    tokens = tokenizer(
        text,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    # Inference only: skip autograd bookkeeping for the encoder forward pass.
    with torch.no_grad():
        return text_encoder(tokens.input_ids.to(device))[0]


def prep_text(prompt):
    """Build the embeddings used for classifier-free guidance.

    Args:
        prompt: Text describing the content semantics.

    Returns:
        The embedding of the empty prompt followed by the embedding of
        *prompt*, concatenated along the first dimension (unconditional
        first, conditional second).
    """
    text_embedding = _embed_text(prompt)
    uncond_embedding = _embed_text("")
    return torch.cat([uncond_embedding, text_embedding])
|
92 |
+
|
93 |
+
|
94 |
+
def _predict_noise(model_latents, t, text_embeddings, guidance_scale):
    """Run the UNet on *model_latents* at timestep *t* and apply guidance."""
    # Duplicate the latents so the unconditional and text-conditioned
    # predictions are computed in a single forward pass.
    model_input = torch.cat([model_latents] * 2)
    model_input = scheduler.scale_model_input(model_input, t)

    with torch.no_grad():
        pred = unet(
            model_input,
            t,
            encoder_hidden_states=text_embeddings,
        ).sample

    pred_uncond, pred_text = pred.chunk(2)
    return pred_uncond + guidance_scale * (pred_text - pred_uncond)


def magic_mix(
    img,  # specifies the layout semantics
    prompt,  # specifies the content semantics
    kmin=0.3,
    kmax=0.6,
    v=0.5,  # interpolation constant
    seed=42,
    steps=50,
    guidance_scale=7.5,
):
    """Mix the layout of *img* with the content described by *prompt*.

    Args:
        img: Image providing the layout semantics.
        prompt: Text providing the content semantics.
        kmin: Fraction of the schedule that determines where the layout
            generation phase ends (``tmin = steps - int(kmin * steps)``).
        kmax: Fraction of the schedule that determines where denoising starts
            (``tmax = steps - int(kmax * steps)``). ``int(kmax * steps)`` must
            lie in ``1..steps``.
        v: Interpolation constant; weight of the generated latents against
            the re-noised layout latents during the layout phase.
        seed: Seed passed to ``torch.manual_seed`` before drawing the noise.
        steps: Number of scheduler timesteps.
        guidance_scale: Classifier-free guidance scale.

    Returns:
        The decoded PIL image.

    Raises:
        ValueError: If *steps* is smaller than 1, or if *kmax* yields a start
            index outside the timestep schedule.
    """
    # UI widgets may deliver whole numbers as floats; normalise them once.
    steps = int(steps)
    seed = int(seed)
    if steps < 1:
        raise ValueError(f"steps must be at least 1, got {steps}")

    tmin = steps - int(kmin * steps)
    tmax = steps - int(kmax * steps)
    # tmax indexes scheduler.timesteps: tmax == steps would run off the end
    # and a negative tmax would silently wrap around to the wrong timestep.
    if not 0 <= tmax < steps:
        raise ValueError(
            f"kmax={kmax} with steps={steps} gives start index {tmax}; "
            f"int(kmax * steps) must be between 1 and {steps}"
        )

    text_embeddings = prep_text(prompt)

    scheduler.set_timesteps(steps)

    # NOTE(review): `encode` samples from the VAE latent distribution before
    # the seed is set below, so that draw is not controlled by `seed`. Left
    # as is to keep the outputs for a given seed unchanged; confirm whether
    # seeding should happen earlier.
    encoded = encode(img)

    torch.manual_seed(seed)
    # Match the noise to the latents that were actually encoded.
    noise = torch.randn(encoded.shape).to(device)

    # Noise the layout latents up to the starting timestep and take the first
    # denoising step there.
    t_start = scheduler.timesteps[tmax]
    latents = scheduler.add_noise(encoded, noise, timesteps=t_start)
    pred = _predict_noise(latents, t_start, text_embeddings, guidance_scale)
    latents = scheduler.step(pred, t_start, latents).prev_sample

    for i, t in enumerate(tqdm(scheduler.timesteps)):
        if i <= tmax:
            # Timesteps up to and including the start index are not denoised
            # inside this loop.
            continue

        if i < tmin:  # layout generation phase
            orig_latents = scheduler.add_noise(encoded, noise, timesteps=t)
            # Interpolate between the layout noise and the conditionally
            # generated latents to preserve the layout semantics.
            model_latents = (v * latents) + (1 - v) * orig_latents
        else:  # content generation phase
            model_latents = latents

        pred = _predict_noise(model_latents, t, text_embeddings, guidance_scale)
        latents = scheduler.step(pred, t, latents).prev_sample

    return decode(latents)
|
172 |
+
|
173 |
+
if __name__ == "__main__":
    # Command-line entry point: read an image, mix it with the prompt and
    # save the result.
    import argparse

    parser = argparse.ArgumentParser()

    parser.add_argument("img_file", type=str, help="image file to provide the layout semantics for the mixing process")
    parser.add_argument("prompt", type=str, help="prompt to provide the content semantics for the mixing process")
    parser.add_argument("out_file", type=str, help="filename to save the generation to")
    parser.add_argument("--kmin", type=float, default=0.3, help="fraction of the schedule controlling where the layout generation phase ends")
    parser.add_argument("--kmax", type=float, default=0.6, help="fraction of the schedule controlling where denoising starts")
    parser.add_argument("--v", type=float, default=0.5, help="interpolation constant")
    parser.add_argument("--seed", type=int, default=42, help="random seed")
    parser.add_argument("--steps", type=int, default=50, help="number of scheduler steps")
    parser.add_argument("--guidance_scale", type=float, default=7.5, help="classifier-free guidance scale")

    args = parser.parse_args()

    # Close the file handle once the pixels are loaded; `convert` returns an
    # independent in-memory RGB copy, so files with an alpha channel or a
    # palette are handled as well.
    with Image.open(args.img_file) as src:
        img = src.convert("RGB")

    out_img = magic_mix(
        img,
        args.prompt,
        args.kmin,
        args.kmax,
        args.v,
        args.seed,
        args.steps,
        args.guidance_scale,
    )
    out_img.save(args.out_file)
|
requirements.txt
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
torchvision
|
3 |
+
diffusers
|
4 |
+
transformers
|
5 |
+
accelerate
|
6 |
+
tqdm
|
7 |
+
pillow
|
8 |
+
gradio
|