Spaces:
Runtime error
Runtime error
Update to v3
Browse files- app.py +126 -52
- requirements.txt +3 -3
app.py
CHANGED
@@ -1,23 +1,26 @@
|
|
1 |
import gradio as gr
|
2 |
-
import open_clip
|
3 |
import torch
|
|
|
|
|
|
|
4 |
from PIL import Image
|
5 |
from open_clip import tokenizer
|
6 |
-
from
|
7 |
-
from einops import rearrange
|
8 |
-
from huggingface_hub import hf_hub_download
|
9 |
-
from modules import DenoiseUNet
|
10 |
from arroz import Diffuzz, PriorModel
|
|
|
|
|
|
|
11 |
|
12 |
-
model_repo = "
|
13 |
-
model_file = "
|
14 |
-
prior_file = "
|
|
|
15 |
|
16 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
17 |
-
device_text = "GPU 🔥" if torch.cuda.is_available() else "CPU 🥶"
|
18 |
|
19 |
batch_size = 4
|
20 |
-
latent_shape = (64, 64)
|
|
|
21 |
|
22 |
generator_timesteps = 12
|
23 |
generator_cfg = 5
|
@@ -98,61 +101,135 @@ def sample(model, c, x=None, negative_embeddings=None, mask=None, T=12, size=(32
|
|
98 |
|
99 |
# Model loading
|
100 |
|
101 |
-
|
102 |
-
|
|
|
103 |
|
|
|
104 |
clip_model, _, _ = open_clip.create_model_and_transforms('ViT-H-14', pretrained='laion2b_s32b_b79k')
|
105 |
clip_model = clip_model.to(device).half().eval().requires_grad_(False)
|
106 |
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
img = vqmodel.model.decode(z)
|
117 |
-
img = (img.clamp(-1., 1.) + 1) * 0.5
|
118 |
-
return img
|
119 |
-
|
120 |
-
model_path = hf_hub_download(repo_id=model_repo, filename=model_file)
|
121 |
-
model = DenoiseUNet(num_labels=8192, c_clip=1024, c_hidden=1280, down_levels=[1, 2, 8, 32], up_levels=[32, 8, 2, 1])
|
122 |
-
model = model.to(device).half()
|
123 |
-
model.load_state_dict(torch.load(model_path, map_location=device))
|
124 |
-
model.eval().requires_grad_()
|
125 |
|
126 |
prior_path = hf_hub_download(repo_id=model_repo, filename=prior_file)
|
127 |
prior = PriorModel().to(device).half()
|
128 |
prior.load_state_dict(torch.load(prior_path, map_location=device))
|
129 |
prior.eval().requires_grad_(False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
130 |
diffuzz = Diffuzz(device=device)
|
131 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
132 |
# -----
|
133 |
|
134 |
def infer(prompt, negative_prompt):
|
135 |
-
|
136 |
-
negative_text = tokenizer.tokenize([negative_prompt] * batch_size).to(device)
|
137 |
with torch.inference_mode():
|
138 |
-
|
139 |
-
|
140 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
141 |
|
142 |
-
|
143 |
-
|
|
|
144 |
timesteps=prior_timesteps, cfg=prior_cfg, sampler=prior_sampler
|
145 |
)[-1]
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
|
|
|
|
|
|
156 |
|
157 |
css = """
|
158 |
.gradio-container {
|
@@ -304,9 +381,6 @@ with block:
|
|
304 |
Paella Demo
|
305 |
</h1>
|
306 |
</div>
|
307 |
-
<p>
|
308 |
-
Running on <b>{device_text}</b>
|
309 |
-
</p>
|
310 |
<p style="margin-bottom: 10px; font-size: 94%">
|
311 |
Paella is a novel text-to-image model that uses a compressed quantized latent space, based on a f8 VQGAN, and a masked training objective to achieve fast generation in ~10 inference steps.
|
312 |
</p>
|
@@ -321,7 +395,7 @@ with block:
|
|
321 |
label="Enter your prompt",
|
322 |
show_label=False,
|
323 |
max_lines=1,
|
324 |
-
placeholder="
|
325 |
elem_id="prompt-text-input",
|
326 |
).style(
|
327 |
border=(True, False, True, True),
|
@@ -332,7 +406,7 @@ with block:
|
|
332 |
label="Enter your negative prompt",
|
333 |
show_label=False,
|
334 |
max_lines=1,
|
335 |
-
placeholder="
|
336 |
elem_id="negative-prompt-text-input",
|
337 |
).style(
|
338 |
border=(True, False, True, True),
|
|
|
1 |
import gradio as gr
|
|
|
2 |
import torch
|
3 |
+
import open_clip
|
4 |
+
import torchvision
|
5 |
+
from huggingface_hub import hf_hub_download
|
6 |
from PIL import Image
|
7 |
from open_clip import tokenizer
|
8 |
+
from Paella.utils.modules import Paella
|
|
|
|
|
|
|
9 |
from arroz import Diffuzz, PriorModel
|
10 |
+
from transformers import AutoTokenizer, T5EncoderModel
|
11 |
+
from Paella.src.vqgan import VQModel
|
12 |
+
from Paella.utils.alter_attention import replace_attention_layers
|
13 |
|
14 |
+
model_repo = "dome272/Paella"
|
15 |
+
model_file = "paella_v3.pt"
|
16 |
+
prior_file = "prior_v1.pt"
|
17 |
+
vqgan_file = "vqgan_f4.pt"
|
18 |
|
19 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
20 |
|
21 |
batch_size = 4
|
22 |
+
latent_shape = (batch_size, 64, 64) # latent shape of the generated image, we are using an f4 vqgan and thus sampling 64x64 will result in 256x256
|
23 |
+
prior_timesteps, prior_cfg, prior_sampler, clip_embedding_shape = 60, 3.0, "ddpm", (batch_size, 1024)
|
24 |
|
25 |
generator_timesteps = 12
|
26 |
generator_cfg = 5
|
|
|
101 |
|
102 |
# Model loading
|
103 |
|
104 |
+
# Load T5 on CPU
|
105 |
+
t5_tokenizer = AutoTokenizer.from_pretrained("google/byt5-xl")
|
106 |
+
t5_model = T5EncoderModel.from_pretrained("google/byt5-xl")
|
107 |
|
108 |
+
# Load other models on GPU
|
109 |
clip_model, _, _ = open_clip.create_model_and_transforms('ViT-H-14', pretrained='laion2b_s32b_b79k')
|
110 |
clip_model = clip_model.to(device).half().eval().requires_grad_(False)
|
111 |
|
112 |
+
clip_preprocess = torchvision.transforms.Compose([
|
113 |
+
torchvision.transforms.Resize(224, interpolation=torchvision.transforms.InterpolationMode.BICUBIC),
|
114 |
+
torchvision.transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)),
|
115 |
+
])
|
116 |
+
|
117 |
+
vqgan_path = hf_hub_download(repo_id=model_repo, filename=vqgan_file)
|
118 |
+
vqmodel = VQModel().to(device)
|
119 |
+
vqmodel.load_state_dict(torch.load(vqgan_path, map_location=device))
|
120 |
+
vqmodel.eval().requires_grad_(False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
121 |
|
122 |
prior_path = hf_hub_download(repo_id=model_repo, filename=prior_file)
|
123 |
prior = PriorModel().to(device).half()
|
124 |
prior.load_state_dict(torch.load(prior_path, map_location=device))
|
125 |
prior.eval().requires_grad_(False)
|
126 |
+
|
127 |
+
model_path = hf_hub_download(repo_id=model_repo, filename=model_file)
|
128 |
+
model = Paella(byt5_embd=2560)
|
129 |
+
model.load_state_dict(torch.load(model_path, map_location=device))
|
130 |
+
model.eval().requires_grad_().half()
|
131 |
+
replace_attention_layers(model)
|
132 |
+
model.to(device)
|
133 |
+
|
134 |
diffuzz = Diffuzz(device=device)
|
135 |
|
136 |
+
@torch.inference_mode()
|
137 |
+
def decode(img_seq):
|
138 |
+
return vqmodel.decode_indices(img_seq)
|
139 |
+
|
140 |
+
@torch.inference_mode()
|
141 |
+
def embed_t5(text, t5_tokenizer, t5_model, final_device="cuda"):
|
142 |
+
device = t5_model.device
|
143 |
+
t5_tokens = t5_tokenizer(text, padding="longest", return_tensors="pt", max_length=768, truncation=True).input_ids.to(device)
|
144 |
+
t5_embeddings = t5_model(input_ids=t5_tokens).last_hidden_state.to(final_device)
|
145 |
+
return t5_embeddings
|
146 |
+
|
147 |
+
@torch.inference_mode()
|
148 |
+
def sample(model, model_inputs, latent_shape,
|
149 |
+
unconditional_inputs=None, init_x=None, steps=12, renoise_steps=None,
|
150 |
+
temperature = (0.7, 0.3), cfg=(8.0, 8.0),
|
151 |
+
mode = 'multinomial', # 'quant', 'multinomial', 'argmax'
|
152 |
+
t_start=1.0, t_end=0.0,
|
153 |
+
sampling_conditional_steps=None, sampling_quant_steps=None, attn_weights=None
|
154 |
+
):
|
155 |
+
device = unconditional_inputs["byt5"].device
|
156 |
+
if sampling_conditional_steps is None:
|
157 |
+
sampling_conditional_steps = steps
|
158 |
+
if sampling_quant_steps is None:
|
159 |
+
sampling_quant_steps = steps
|
160 |
+
if renoise_steps is None:
|
161 |
+
renoise_steps = steps-1
|
162 |
+
if unconditional_inputs is None:
|
163 |
+
unconditional_inputs = {k: torch.zeros_like(v) for k, v in model_inputs.items()}
|
164 |
+
|
165 |
+
init_noise = torch.randint(0, model.num_labels, size=latent_shape, device=device)
|
166 |
+
if init_x != None:
|
167 |
+
sampled = init_x
|
168 |
+
else:
|
169 |
+
sampled = init_noise.clone()
|
170 |
+
t_list = torch.linspace(t_start, t_end, steps+1)
|
171 |
+
temperatures = torch.linspace(temperature[0], temperature[1], steps)
|
172 |
+
cfgs = torch.linspace(cfg[0], cfg[1], steps)
|
173 |
+
for i, tv in enumerate(t_list[:steps]):
|
174 |
+
if i >= sampling_quant_steps:
|
175 |
+
mode = "quant"
|
176 |
+
t = torch.ones(latent_shape[0], device=device) * tv
|
177 |
+
|
178 |
+
logits = model(sampled, t, **model_inputs, attn_weights=attn_weights)
|
179 |
+
if cfg is not None and i < sampling_conditional_steps:
|
180 |
+
logits = logits * cfgs[i] + model(sampled, t, **unconditional_inputs) * (1-cfgs[i])
|
181 |
+
scores = logits.div(temperatures[i]).softmax(dim=1)
|
182 |
+
|
183 |
+
if mode == 'argmax':
|
184 |
+
sampled = logits.argmax(dim=1)
|
185 |
+
elif mode == 'multinomial':
|
186 |
+
sampled = scores.permute(0, 2, 3, 1).reshape(-1, logits.size(1))
|
187 |
+
sampled = torch.multinomial(sampled, 1)[:, 0].view(logits.size(0), *logits.shape[2:])
|
188 |
+
elif mode == 'quant':
|
189 |
+
sampled = scores.permute(0, 2, 3, 1) @ vqmodel.vquantizer.codebook.weight.data
|
190 |
+
sampled = vqmodel.vquantizer.forward(sampled, dim=-1)[-1]
|
191 |
+
else:
|
192 |
+
raise Exception(f"Mode '{mode}' not supported, use: 'quant', 'multinomial' or 'argmax'")
|
193 |
+
|
194 |
+
if i < renoise_steps:
|
195 |
+
t_next = torch.ones(latent_shape[0], device=device) * t_list[i+1]
|
196 |
+
sampled = model.add_noise(sampled, t_next, random_x=init_noise)[0]
|
197 |
+
return sampled
|
198 |
+
|
199 |
# -----
|
200 |
|
201 |
def infer(prompt, negative_prompt):
|
202 |
+
text = tokenizer.tokenize([prompt] * latent_shape[0]).to(device)
|
|
|
203 |
with torch.inference_mode():
|
204 |
+
if negative_prompt:
|
205 |
+
clip_text_tokens_uncond = tokenizer.tokenize([negative_prompt] * len(text)).to(device)
|
206 |
+
t5_embeddings_uncond = embed_t5([negative_prompt] * len(text), t5_tokenizer, t5_model)
|
207 |
+
else:
|
208 |
+
clip_text_tokens_uncond = tokenizer.tokenize([""] * len(text)).to(device)
|
209 |
+
t5_embeddings_uncond = embed_t5([""] * len(text), t5_tokenizer, t5_model)
|
210 |
+
|
211 |
+
t5_embeddings = embed_t5([prompt] * latent_shape[0], t5_tokenizer, t5_model)
|
212 |
+
clip_text_embeddings = clip_model.encode_text(text)
|
213 |
+
clip_text_embeddings_uncond = clip_model.encode_text(clip_text_tokens_uncond)
|
214 |
|
215 |
+
with torch.autocast(device_type="cuda"):
|
216 |
+
clip_image_embeddings = diffuzz.sample(
|
217 |
+
prior, {'c': clip_text_embeddings}, clip_embedding_shape,
|
218 |
timesteps=prior_timesteps, cfg=prior_cfg, sampler=prior_sampler
|
219 |
)[-1]
|
220 |
+
|
221 |
+
attn_weights = torch.ones((t5_embeddings.shape[1]))
|
222 |
+
attn_weights[-4:] = 0.4 # reweigh attention weights for image embeddings --> less influence
|
223 |
+
attn_weights[:-4] = 1.2 # reweigh attention weights for the rest --> more influence
|
224 |
+
attn_weights = attn_weights.to(device)
|
225 |
+
|
226 |
+
sampled_tokens = sample(model,
|
227 |
+
model_inputs={'byt5': t5_embeddings, 'clip': clip_text_embeddings, 'clip_image': clip_image_embeddings}, unconditional_inputs={'byt5': t5_embeddings_uncond, 'clip': clip_text_embeddings_uncond, 'clip_image': None},
|
228 |
+
temperature=(1.2, 0.2), cfg=(8,8), steps=32, renoise_steps=26, latent_shape=latent_shape, t_start=1.0, t_end=0.0,
|
229 |
+
mode="multinomial", sampling_conditional_steps=20, attn_weights=attn_weights)
|
230 |
+
|
231 |
+
sampled = decode(sampled_tokens)
|
232 |
+
return to_pil(sampled.clamp(0, 1))
|
233 |
|
234 |
css = """
|
235 |
.gradio-container {
|
|
|
381 |
Paella Demo
|
382 |
</h1>
|
383 |
</div>
|
|
|
|
|
|
|
384 |
<p style="margin-bottom: 10px; font-size: 94%">
|
385 |
Paella is a novel text-to-image model that uses a compressed quantized latent space, based on a f8 VQGAN, and a masked training objective to achieve fast generation in ~10 inference steps.
|
386 |
</p>
|
|
|
395 |
label="Enter your prompt",
|
396 |
show_label=False,
|
397 |
max_lines=1,
|
398 |
+
placeholder="an image of a shiba inu, donning a spacesuit and helmet, traversing the uncharted terrain of a distant, extraterrestrial world, as a symbol of the intrepid spirit of exploration and the unrelenting curiosity that drives humanity to push beyond the bounds of the known",
|
399 |
elem_id="prompt-text-input",
|
400 |
).style(
|
401 |
border=(True, False, True, True),
|
|
|
406 |
label="Enter your negative prompt",
|
407 |
show_label=False,
|
408 |
max_lines=1,
|
409 |
+
placeholder="low quality, low resolution, bad image, blurry, blur",
|
410 |
elem_id="negative-prompt-text-input",
|
411 |
).style(
|
412 |
border=(True, False, True, True),
|
requirements.txt
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
torch
|
2 |
open_clip_torch
|
3 |
-
einops
|
4 |
Pillow
|
5 |
huggingface_hub
|
6 |
-
git+https://github.com/
|
7 |
-
git+https://github.com/
|
|
|
|
1 |
torch
|
2 |
open_clip_torch
|
|
|
3 |
Pillow
|
4 |
huggingface_hub
|
5 |
+
git+https://github.com/pabloppp/pytorch-tools
|
6 |
+
git+https://github.com/pabloppp/Arroz-Con-Cosas
|
7 |
+
git+https://github.com/fbcotter/pytorch_wavelets
|