Spaces:
Running
Running
fix: correct clip params
Browse files
tools/inference/log_inference_samples.ipynb
CHANGED
@@ -24,25 +24,6 @@
|
|
24 |
"from dalle_mini.text import TextNormalizer"
|
25 |
]
|
26 |
},
|
27 |
-
{
|
28 |
-
"cell_type": "code",
|
29 |
-
"execution_count": null,
|
30 |
-
"id": "23e00271-941c-4e1b-b6a9-107a1b77324d",
|
31 |
-
"metadata": {},
|
32 |
-
"outputs": [],
|
33 |
-
"source": [
|
34 |
-
"run_ids = ['3kaut6e8']\n",
|
35 |
-
"# Alamy - 3kaut6e8\n",
|
36 |
-
"# YFCC - to do\n",
|
37 |
-
"# HF spaces - 4oh3u7ca\n",
|
38 |
-
"ENTITY, PROJECT = 'wandb', 'hf-flax-dalle-mini'\n",
|
39 |
-
"VQGAN_REPO, VQGAN_COMMIT_ID = 'dalle-mini/vqgan_imagenet_f16_16384', None\n",
|
40 |
-
"normalize_text = False\n",
|
41 |
-
"latest_only = True # log only latest or all versions\n",
|
42 |
-
"suffix = '' # mainly for duplicate inference runs with a deleted version\n",
|
43 |
-
"add_clip_32 = False"
|
44 |
-
]
|
45 |
-
},
|
46 |
{
|
47 |
"cell_type": "code",
|
48 |
"execution_count": null,
|
@@ -50,13 +31,9 @@
|
|
50 |
"metadata": {},
|
51 |
"outputs": [],
|
52 |
"source": [
|
53 |
-
"run_ids = ['
|
54 |
-
"# poorly shuffled 1nj161cl\n",
|
55 |
-
"# well shuffled he9rrc3q\n",
|
56 |
-
"# non normalized 1fwxpyfh ! requires changing normalize_text\n",
|
57 |
"ENTITY, PROJECT = 'dalle-mini', 'dalle-mini' # used only for training run\n",
|
58 |
-
"VQGAN_REPO, VQGAN_COMMIT_ID = 'dalle-mini/vqgan_imagenet_f16_16384',
|
59 |
-
"normalize_text = True\n",
|
60 |
"latest_only = True # log only latest or all versions\n",
|
61 |
"suffix = '' # mainly for duplicate inference runs with a deleted version\n",
|
62 |
"add_clip_32 = False"
|
@@ -85,7 +62,7 @@
|
|
85 |
"batch_size = 8\n",
|
86 |
"num_images = 128\n",
|
87 |
"top_k = 8\n",
|
88 |
-
"text_normalizer = TextNormalizer()
|
89 |
"padding_item = 'NONE'\n",
|
90 |
"seed = random.randint(0, 2**32-1)\n",
|
91 |
"key = jax.random.PRNGKey(seed)\n",
|
@@ -230,7 +207,7 @@
|
|
230 |
"outputs": [],
|
231 |
"source": [
|
232 |
"run_id = run_ids[0]\n",
|
233 |
-
"# TODO:
|
234 |
]
|
235 |
},
|
236 |
{
|
@@ -287,7 +264,7 @@
|
|
287 |
"\n",
|
288 |
" # process one batch of captions\n",
|
289 |
" for batch in tqdm(samples):\n",
|
290 |
-
" processed_prompts = [text_normalizer(x) for x in batch] if normalize_text else list(batch)\n",
|
291 |
"\n",
|
292 |
" # repeat the prompts to distribute over each device and tokenize\n",
|
293 |
" processed_prompts = processed_prompts * jax.device_count()\n",
|
@@ -296,7 +273,7 @@
|
|
296 |
"\n",
|
297 |
" # generate images\n",
|
298 |
" images = []\n",
|
299 |
-
" pbar = tqdm(range(num_images // jax.device_count()), desc='Generating Images', leave=
|
300 |
" for i in pbar:\n",
|
301 |
" key, subkey = jax.random.split(key)\n",
|
302 |
" encoded_images = p_generate(tokenized_prompt, shard_prng_key(subkey), model_params)\n",
|
@@ -312,7 +289,7 @@
|
|
312 |
" images_per_prompt_indices = np.asarray(range(0, len(images), batch_size))\n",
|
313 |
" clip_inputs['pixel_values'] = jnp.concatenate(list(clip_inputs['pixel_values'][images_per_prompt_indices + i] for i in range(batch_size)))\n",
|
314 |
" clip_inputs = shard(clip_inputs)\n",
|
315 |
-
" logits = p_clip(clip_inputs,
|
316 |
" logits = logits.reshape(-1, num_images)\n",
|
317 |
" top_scores = logits.argsort()[:, -top_k:][..., ::-1]\n",
|
318 |
" logits = jax.device_get(logits)\n",
|
@@ -348,6 +325,14 @@
|
|
348 |
" wandb.finish()\n",
|
349 |
" run = None # ensure we don't log on this run"
|
350 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
351 |
}
|
352 |
],
|
353 |
"metadata": {
|
|
|
24 |
"from dalle_mini.text import TextNormalizer"
|
25 |
]
|
26 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
{
|
28 |
"cell_type": "code",
|
29 |
"execution_count": null,
|
|
|
31 |
"metadata": {},
|
32 |
"outputs": [],
|
33 |
"source": [
|
34 |
+
"run_ids = ['63otg87g']\n",
|
|
|
|
|
|
|
35 |
"ENTITY, PROJECT = 'dalle-mini', 'dalle-mini' # used only for training run\n",
|
36 |
+
"VQGAN_REPO, VQGAN_COMMIT_ID = 'dalle-mini/vqgan_imagenet_f16_16384', 'e93a26e7707683d349bf5d5c41c5b0ef69b677a9'\n",
|
|
|
37 |
"latest_only = True # log only latest or all versions\n",
|
38 |
"suffix = '' # mainly for duplicate inference runs with a deleted version\n",
|
39 |
"add_clip_32 = False"
|
|
|
62 |
"batch_size = 8\n",
|
63 |
"num_images = 128\n",
|
64 |
"top_k = 8\n",
|
65 |
+
"text_normalizer = TextNormalizer()\n",
|
66 |
"padding_item = 'NONE'\n",
|
67 |
"seed = random.randint(0, 2**32-1)\n",
|
68 |
"key = jax.random.PRNGKey(seed)\n",
|
|
|
207 |
"outputs": [],
|
208 |
"source": [
|
209 |
"run_id = run_ids[0]\n",
|
210 |
+
"# TODO: loop over runs"
|
211 |
]
|
212 |
},
|
213 |
{
|
|
|
264 |
"\n",
|
265 |
" # process one batch of captions\n",
|
266 |
" for batch in tqdm(samples):\n",
|
267 |
+
" processed_prompts = [text_normalizer(x) for x in batch] if model.config.normalize_text else list(batch)\n",
|
268 |
"\n",
|
269 |
" # repeat the prompts to distribute over each device and tokenize\n",
|
270 |
" processed_prompts = processed_prompts * jax.device_count()\n",
|
|
|
273 |
"\n",
|
274 |
" # generate images\n",
|
275 |
" images = []\n",
|
276 |
+
" pbar = tqdm(range(num_images // jax.device_count()), desc='Generating Images', leave=True)\n",
|
277 |
" for i in pbar:\n",
|
278 |
" key, subkey = jax.random.split(key)\n",
|
279 |
" encoded_images = p_generate(tokenized_prompt, shard_prng_key(subkey), model_params)\n",
|
|
|
289 |
" images_per_prompt_indices = np.asarray(range(0, len(images), batch_size))\n",
|
290 |
" clip_inputs['pixel_values'] = jnp.concatenate(list(clip_inputs['pixel_values'][images_per_prompt_indices + i] for i in range(batch_size)))\n",
|
291 |
" clip_inputs = shard(clip_inputs)\n",
|
292 |
+
" logits = p_clip(clip_inputs, clip_params)\n",
|
293 |
" logits = logits.reshape(-1, num_images)\n",
|
294 |
" top_scores = logits.argsort()[:, -top_k:][..., ::-1]\n",
|
295 |
" logits = jax.device_get(logits)\n",
|
|
|
325 |
" wandb.finish()\n",
|
326 |
" run = None # ensure we don't log on this run"
|
327 |
]
|
328 |
+
},
|
329 |
+
{
|
330 |
+
"cell_type": "code",
|
331 |
+
"execution_count": null,
|
332 |
+
"id": "415d3f54-7226-43de-9eea-4283a948dc93",
|
333 |
+
"metadata": {},
|
334 |
+
"outputs": [],
|
335 |
+
"source": []
|
336 |
}
|
337 |
],
|
338 |
"metadata": {
|