Commit 3875a6e by pcuenq (HF staff)
Parent: bae1c0e

Update to v3

Files changed (2):
  1. app.py +126 -52
  2. requirements.txt +3 -3
app.py CHANGED
@@ -1,23 +1,26 @@
 import gradio as gr
-import open_clip
 import torch
+import open_clip
+import torchvision
+from huggingface_hub import hf_hub_download
 from PIL import Image
 from open_clip import tokenizer
-from rudalle import get_vae
-from einops import rearrange
-from huggingface_hub import hf_hub_download
-from modules import DenoiseUNet
+from Paella.utils.modules import Paella
 from arroz import Diffuzz, PriorModel
+from transformers import AutoTokenizer, T5EncoderModel
+from Paella.src.vqgan import VQModel
+from Paella.utils.alter_attention import replace_attention_layers
 
-model_repo = "pcuenq/Arroz_con_cosas"
-model_file = "model_1b_img.pt"
-prior_file = "prior_v1_1500k_ema_fp16.pt"
+model_repo = "dome272/Paella"
+model_file = "paella_v3.pt"
+prior_file = "prior_v1.pt"
+vqgan_file = "vqgan_f4.pt"
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
-device_text = "GPU 🔥" if torch.cuda.is_available() else "CPU 🥶"
 
 batch_size = 4
-latent_shape = (64, 64)
+latent_shape = (batch_size, 64, 64) # latent shape of the generated image, we are using an f4 vqgan and thus sampling 64x64 will result in 256x256
+prior_timesteps, prior_cfg, prior_sampler, clip_embedding_shape = 60, 3.0, "ddpm", (batch_size, 1024)
 
 generator_timesteps = 12
 generator_cfg = 5
@@ -98,61 +101,135 @@ def sample(model, c, x=None, negative_embeddings=None, mask=None, T=12, size=(32
 
 # Model loading
 
-vqmodel = get_vae().to(device)
-vqmodel.eval().requires_grad_(False)
+# Load T5 on CPU
+t5_tokenizer = AutoTokenizer.from_pretrained("google/byt5-xl")
+t5_model = T5EncoderModel.from_pretrained("google/byt5-xl")
 
+# Load other models on GPU
 clip_model, _, _ = open_clip.create_model_and_transforms('ViT-H-14', pretrained='laion2b_s32b_b79k')
 clip_model = clip_model.to(device).half().eval().requires_grad_(False)
 
-def encode(x):
-    return vqmodel.model.encode((2 * x - 1))[-1][-1]
-
-def decode(img_seq, shape=(32,32)):
-    img_seq = img_seq.view(img_seq.shape[0], -1)
-    b, n = img_seq.shape
-    one_hot_indices = torch.nn.functional.one_hot(img_seq, num_classes=vqmodel.num_tokens).float()
-    z = (one_hot_indices @ vqmodel.model.quantize.embed.weight)
-    z = rearrange(z, 'b (h w) c -> b c h w', h=shape[0], w=shape[1])
-    img = vqmodel.model.decode(z)
-    img = (img.clamp(-1., 1.) + 1) * 0.5
-    return img
-
-model_path = hf_hub_download(repo_id=model_repo, filename=model_file)
-model = DenoiseUNet(num_labels=8192, c_clip=1024, c_hidden=1280, down_levels=[1, 2, 8, 32], up_levels=[32, 8, 2, 1])
-model = model.to(device).half()
-model.load_state_dict(torch.load(model_path, map_location=device))
-model.eval().requires_grad_()
+clip_preprocess = torchvision.transforms.Compose([
+    torchvision.transforms.Resize(224, interpolation=torchvision.transforms.InterpolationMode.BICUBIC),
+    torchvision.transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)),
+])
+
+vqgan_path = hf_hub_download(repo_id=model_repo, filename=vqgan_file)
+vqmodel = VQModel().to(device)
+vqmodel.load_state_dict(torch.load(vqgan_path, map_location=device))
+vqmodel.eval().requires_grad_(False)
 
 prior_path = hf_hub_download(repo_id=model_repo, filename=prior_file)
 prior = PriorModel().to(device).half()
 prior.load_state_dict(torch.load(prior_path, map_location=device))
 prior.eval().requires_grad_(False)
+
+model_path = hf_hub_download(repo_id=model_repo, filename=model_file)
+model = Paella(byt5_embd=2560)
+model.load_state_dict(torch.load(model_path, map_location=device))
+model.eval().requires_grad_().half()
+replace_attention_layers(model)
+model.to(device)
+
 diffuzz = Diffuzz(device=device)
 
+@torch.inference_mode()
+def decode(img_seq):
+    return vqmodel.decode_indices(img_seq)
+
+@torch.inference_mode()
+def embed_t5(text, t5_tokenizer, t5_model, final_device="cuda"):
+    device = t5_model.device
+    t5_tokens = t5_tokenizer(text, padding="longest", return_tensors="pt", max_length=768, truncation=True).input_ids.to(device)
+    t5_embeddings = t5_model(input_ids=t5_tokens).last_hidden_state.to(final_device)
+    return t5_embeddings
+
+@torch.inference_mode()
+def sample(model, model_inputs, latent_shape,
+           unconditional_inputs=None, init_x=None, steps=12, renoise_steps=None,
+           temperature = (0.7, 0.3), cfg=(8.0, 8.0),
+           mode = 'multinomial', # 'quant', 'multinomial', 'argmax'
+           t_start=1.0, t_end=0.0,
+           sampling_conditional_steps=None, sampling_quant_steps=None, attn_weights=None
+           ):
+    device = unconditional_inputs["byt5"].device
+    if sampling_conditional_steps is None:
+        sampling_conditional_steps = steps
+    if sampling_quant_steps is None:
+        sampling_quant_steps = steps
+    if renoise_steps is None:
+        renoise_steps = steps-1
+    if unconditional_inputs is None:
+        unconditional_inputs = {k: torch.zeros_like(v) for k, v in model_inputs.items()}
+
+    init_noise = torch.randint(0, model.num_labels, size=latent_shape, device=device)
+    if init_x != None:
+        sampled = init_x
+    else:
+        sampled = init_noise.clone()
+    t_list = torch.linspace(t_start, t_end, steps+1)
+    temperatures = torch.linspace(temperature[0], temperature[1], steps)
+    cfgs = torch.linspace(cfg[0], cfg[1], steps)
+    for i, tv in enumerate(t_list[:steps]):
+        if i >= sampling_quant_steps:
+            mode = "quant"
+        t = torch.ones(latent_shape[0], device=device) * tv
+
+        logits = model(sampled, t, **model_inputs, attn_weights=attn_weights)
+        if cfg is not None and i < sampling_conditional_steps:
+            logits = logits * cfgs[i] + model(sampled, t, **unconditional_inputs) * (1-cfgs[i])
+        scores = logits.div(temperatures[i]).softmax(dim=1)
+
+        if mode == 'argmax':
+            sampled = logits.argmax(dim=1)
+        elif mode == 'multinomial':
+            sampled = scores.permute(0, 2, 3, 1).reshape(-1, logits.size(1))
+            sampled = torch.multinomial(sampled, 1)[:, 0].view(logits.size(0), *logits.shape[2:])
+        elif mode == 'quant':
+            sampled = scores.permute(0, 2, 3, 1) @ vqmodel.vquantizer.codebook.weight.data
+            sampled = vqmodel.vquantizer.forward(sampled, dim=-1)[-1]
+        else:
+            raise Exception(f"Mode '{mode}' not supported, use: 'quant', 'multinomial' or 'argmax'")
+
+        if i < renoise_steps:
+            t_next = torch.ones(latent_shape[0], device=device) * t_list[i+1]
+            sampled = model.add_noise(sampled, t_next, random_x=init_noise)[0]
+    return sampled
+
 # -----
 
 def infer(prompt, negative_prompt):
-    tokenized_text = tokenizer.tokenize([prompt] * batch_size).to(device)
-    negative_text = tokenizer.tokenize([negative_prompt] * batch_size).to(device)
+    text = tokenizer.tokenize([prompt] * latent_shape[0]).to(device)
    with torch.inference_mode():
-        with torch.autocast(device_type="cuda"):
-            clip_embeddings = clip_model.encode_text(tokenized_text)
-            neg_clip_embeddings = clip_model.encode_text(negative_text)
+        if negative_prompt:
+            clip_text_tokens_uncond = tokenizer.tokenize([negative_prompt] * len(text)).to(device)
+            t5_embeddings_uncond = embed_t5([negative_prompt] * len(text), t5_tokenizer, t5_model)
+        else:
+            clip_text_tokens_uncond = tokenizer.tokenize([""] * len(text)).to(device)
+            t5_embeddings_uncond = embed_t5([""] * len(text), t5_tokenizer, t5_model)
+
+        t5_embeddings = embed_t5([prompt] * latent_shape[0], t5_tokenizer, t5_model)
+        clip_text_embeddings = clip_model.encode_text(text)
+        clip_text_embeddings_uncond = clip_model.encode_text(clip_text_tokens_uncond)
 
-            sampled_image_embeddings = diffuzz.sample(
-                prior, {'c': clip_embeddings}, clip_embedding_shape,
+        with torch.autocast(device_type="cuda"):
+            clip_image_embeddings = diffuzz.sample(
+                prior, {'c': clip_text_embeddings}, clip_embedding_shape,
                 timesteps=prior_timesteps, cfg=prior_cfg, sampler=prior_sampler
            )[-1]
-
-            images = sample(
-                model, sampled_image_embeddings, negative_embeddings=neg_clip_embeddings,
-                T=generator_timesteps, size=latent_shape, starting_t=0, temp_range=[2.0, 0.1],
-                typical_filtering=False, typical_mass=0.2, typical_min_tokens=1,
-                classifier_free_scale=generator_cfg, renoise_steps=generator_timesteps-1,
-                renoise_mode="start"
-            )
-            images = decode(images[-1], latent_shape)
-            return to_pil(images)
+
+        attn_weights = torch.ones((t5_embeddings.shape[1]))
+        attn_weights[-4:] = 0.4 # reweigh attention weights for image embeddings --> less influence
+        attn_weights[:-4] = 1.2 # reweigh attention weights for the rest --> more influence
+        attn_weights = attn_weights.to(device)
+
+        sampled_tokens = sample(model,
+                                model_inputs={'byt5': t5_embeddings, 'clip': clip_text_embeddings, 'clip_image': clip_image_embeddings}, unconditional_inputs={'byt5': t5_embeddings_uncond, 'clip': clip_text_embeddings_uncond, 'clip_image': None},
+                                temperature=(1.2, 0.2), cfg=(8,8), steps=32, renoise_steps=26, latent_shape=latent_shape, t_start=1.0, t_end=0.0,
+                                mode="multinomial", sampling_conditional_steps=20, attn_weights=attn_weights)
+
+        sampled = decode(sampled_tokens)
+        return to_pil(sampled.clamp(0, 1))
 
 css = """
 .gradio-container {
@@ -304,9 +381,6 @@ with block:
                 Paella Demo
               </h1>
             </div>
-            <p>
-              Running on <b>{device_text}</b>
-            </p>
             <p style="margin-bottom: 10px; font-size: 94%">
              Paella is a novel text-to-image model that uses a compressed quantized latent space, based on a f8 VQGAN, and a masked training objective to achieve fast generation in ~10 inference steps.
             </p>
@@ -321,7 +395,7 @@ with block:
                     label="Enter your prompt",
                     show_label=False,
                     max_lines=1,
-                    placeholder="Enter your prompt",
+                    placeholder="an image of a shiba inu, donning a spacesuit and helmet, traversing the uncharted terrain of a distant, extraterrestrial world, as a symbol of the intrepid spirit of exploration and the unrelenting curiosity that drives humanity to push beyond the bounds of the known",
                     elem_id="prompt-text-input",
                 ).style(
                     border=(True, False, True, True),
@@ -332,7 +406,7 @@ with block:
                     label="Enter your negative prompt",
                     show_label=False,
                     max_lines=1,
-                    placeholder="Enter a negative prompt",
+                    placeholder="low quality, low resolution, bad image, blurry, blur",
                    elem_id="negative-prompt-text-input",
                 ).style(
                     border=(True, False, True, True),
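
The updated infer above chains the v3 pieces end to end: ByT5 and CLIP text embeddings, a diffusion prior that produces a CLIP image embedding, token re-sampling with sample, and decoding through the f4 VQGAN. A minimal sketch of driving it programmatically, assuming app.py's globals are available in the same session; to_pil is defined elsewhere in app.py and its exact return type is an assumption here:

# Hypothetical driver, not part of the commit: calls the updated pipeline without the Gradio UI.
# Assumes app.py's globals (infer, to_pil, the loaded models) already exist in this session,
# and that to_pil returns either a single PIL image or a list of PIL images.
def save_samples(prompt, negative_prompt="low quality, blurry", prefix="paella_v3"):
    result = infer(prompt, negative_prompt)
    images = result if isinstance(result, (list, tuple)) else [result]
    for i, img in enumerate(images):
        img.save(f"{prefix}_{i}.png")

save_samples("a paella dish on a rustic wooden table, food photography")
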
requirements.txt CHANGED
@@ -1,7 +1,7 @@
 torch
 open_clip_torch
-einops
 Pillow
 huggingface_hub
-git+https://github.com/pcuenca/Arroz-Con-Cosas
-git+https://github.com/ai-forever/ru-dalle
+git+https://github.com/pabloppp/pytorch-tools
+git+https://github.com/pabloppp/Arroz-Con-Cosas
+git+https://github.com/fbcotter/pytorch_wavelets
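
The new requirements pull pabloppp/pytorch-tools, pabloppp/Arroz-Con-Cosas and fbcotter/pytorch_wavelets straight from GitHub, while the model weights are fetched at startup with hf_hub_download. A small sketch for warming the local cache before the Space boots, using only the repo id and filenames that appear in the new app.py; the script itself is an assumption, not part of the commit:

# Hypothetical cache-warming script: pre-downloads the checkpoints referenced in app.py
# so the first launch of the Space does not block on large downloads.
from huggingface_hub import hf_hub_download

MODEL_REPO = "dome272/Paella"  # repo id from the new app.py
for filename in ("paella_v3.pt", "prior_v1.pt", "vqgan_f4.pt"):
    path = hf_hub_download(repo_id=MODEL_REPO, filename=filename)
    print(f"{filename} -> {path}")
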