pcuenca commited on
Commit
eb591ff
2 Parent(s): 648e404 2c2f570

Merge pull request #23 from khalidsaifullaah/main

Browse files
encoding/vqgan-jax-encoding-yfcc100m-splitted.ipynb DELETED
@@ -1,462 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "markdown",
5
- "id": "d0b72877",
6
- "metadata": {},
7
- "source": [
8
- "# vqgan-jax-encoding-yfcc100m"
9
- ]
10
- },
11
- {
12
- "cell_type": "markdown",
13
- "id": "747733a4",
14
- "metadata": {},
15
- "source": [
16
- "Same as `vqgan-jax-encoding-with-captions`, but for YFCC100M.\n",
17
- "\n",
18
- "This dataset was prepared by @borisdayma in Json lines format."
19
- ]
20
- },
21
- {
22
- "cell_type": "code",
23
- "execution_count": 1,
24
- "id": "3b59489e",
25
- "metadata": {},
26
- "outputs": [],
27
- "source": [
28
- "import io\n",
29
- "\n",
30
- "import requests\n",
31
- "from PIL import Image\n",
32
- "import numpy as np\n",
33
- "from tqdm import tqdm\n",
34
- "\n",
35
- "import torch\n",
36
- "import torchvision.transforms as T\n",
37
- "import torchvision.transforms.functional as TF\n",
38
- "from torchvision.transforms import InterpolationMode\n",
39
- "from torch.utils.data import Dataset, DataLoader\n",
40
- "from torchvision.datasets.folder import default_loader\n",
41
- "\n",
42
- "import jax\n",
43
- "from jax import pmap"
44
- ]
45
- },
46
- {
47
- "cell_type": "markdown",
48
- "id": "511c3b9e",
49
- "metadata": {},
50
- "source": [
51
- "## VQGAN-JAX model"
52
- ]
53
- },
54
- {
55
- "cell_type": "markdown",
56
- "id": "bb408f6c",
57
- "metadata": {},
58
- "source": [
59
- "`dalle_mini` is a local package that contains the VQGAN-JAX model and other utilities."
60
- ]
61
- },
62
- {
63
- "cell_type": "code",
64
- "execution_count": 2,
65
- "id": "2ca50dc7",
66
- "metadata": {},
67
- "outputs": [],
68
- "source": [
69
- "from dalle_mini.vqgan_jax.modeling_flax_vqgan import VQModel"
70
- ]
71
- },
72
- {
73
- "cell_type": "markdown",
74
- "id": "7b60da9a",
75
- "metadata": {},
76
- "source": [
77
- "We'll use a VQGAN trained by using Taming Transformers and converted to a JAX model."
78
- ]
79
- },
80
- {
81
- "cell_type": "code",
82
- "execution_count": 4,
83
- "id": "29ce8b15",
84
- "metadata": {},
85
- "outputs": [],
86
- "source": [
87
- "model = VQModel.from_pretrained(\"flax-community/vqgan_f16_16384\")"
88
- ]
89
- },
90
- {
91
- "cell_type": "markdown",
92
- "id": "c7c4c1e6",
93
- "metadata": {},
94
- "source": [
95
- "## Dataset"
96
- ]
97
- },
98
- {
99
- "cell_type": "markdown",
100
- "id": "fd4c608e",
101
- "metadata": {},
102
- "source": [
103
- "I splitted the files to do the process iteratively. Pandas struggles with memory and `datasets` has problems when filtering files, as described [in this issue](https://github.com/huggingface/datasets/issues/2644)."
104
- ]
105
- },
106
- {
107
- "cell_type": "code",
108
- "execution_count": 5,
109
- "id": "6c058636",
110
- "metadata": {},
111
- "outputs": [],
112
- "source": [
113
- "import pandas as pd\n",
114
- "from pathlib import Path"
115
- ]
116
- },
117
- {
118
- "cell_type": "code",
119
- "execution_count": 6,
120
- "id": "81b19eca",
121
- "metadata": {},
122
- "outputs": [],
123
- "source": [
124
- "yfcc100m = Path('/sddata/dalle-mini/YFCC100M_OpenAI_subset')\n",
125
- "# Images are 'sharded' from the following directory\n",
126
- "yfcc100m_images = yfcc100m/'data'/'images'\n",
127
- "yfcc100m_metadata_splits = yfcc100m/'metadata_splitted'\n",
128
- "yfcc100m_output = yfcc100m/'metadata_encoded'"
129
- ]
130
- },
131
- {
132
- "cell_type": "code",
133
- "execution_count": 7,
134
- "id": "40873de9",
135
- "metadata": {},
136
- "outputs": [
137
- {
138
- "data": {
139
- "text/plain": [
140
- "[PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_04'),\n",
141
- " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_25'),\n",
142
- " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_17'),\n",
143
- " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_10'),\n",
144
- " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_22'),\n",
145
- " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_28'),\n",
146
- " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_09'),\n",
147
- " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_03'),\n",
148
- " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_07'),\n",
149
- " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_26'),\n",
150
- " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_14'),\n",
151
- " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_19'),\n",
152
- " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_13'),\n",
153
- " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_21'),\n",
154
- " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_00'),\n",
155
- " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_02'),\n",
156
- " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_08'),\n",
157
- " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_11'),\n",
158
- " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_29'),\n",
159
- " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_23'),\n",
160
- " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_24'),\n",
161
- " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_16'),\n",
162
- " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_05'),\n",
163
- " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_01'),\n",
164
- " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_12'),\n",
165
- " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_18'),\n",
166
- " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_20'),\n",
167
- " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_27'),\n",
168
- " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_15'),\n",
169
- " PosixPath('/sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_06')]"
170
- ]
171
- },
172
- "execution_count": 7,
173
- "metadata": {},
174
- "output_type": "execute_result"
175
- }
176
- ],
177
- "source": [
178
- "all_splits = [x for x in yfcc100m_metadata_splits.iterdir() if x.is_file()]\n",
179
- "all_splits"
180
- ]
181
- },
182
- {
183
- "cell_type": "markdown",
184
- "id": "f604e3c9",
185
- "metadata": {},
186
- "source": [
187
- "### Cleanup"
188
- ]
189
- },
190
- {
191
- "cell_type": "code",
192
- "execution_count": 8,
193
- "id": "dea06b92",
194
- "metadata": {},
195
- "outputs": [],
196
- "source": [
197
- "def image_exists(root: str, name: str, ext: str):\n",
198
- " image_path = (Path(root)/name[0:3]/name[3:6]/name).with_suffix(ext)\n",
199
- " return image_path.exists()"
200
- ]
201
- },
202
- {
203
- "cell_type": "code",
204
- "execution_count": 9,
205
- "id": "1d34d7aa",
206
- "metadata": {},
207
- "outputs": [],
208
- "source": [
209
- "class YFC100Dataset(Dataset):\n",
210
- " def __init__(self, image_list: pd.DataFrame, images_root: str, image_size: int, max_items=None):\n",
211
- " \"\"\"\n",
212
- " :param image_list: DataFrame with clean entries - all images must exist.\n",
213
- " :param images_root: Root directory containing the images\n",
214
- " :param image_size: Image size. Source images will be resized and center-cropped.\n",
215
- " :max_items: Limit dataset size for debugging\n",
216
- " \"\"\"\n",
217
- " self.image_list = image_list\n",
218
- " self.images_root = Path(images_root)\n",
219
- " if max_items is not None: self.image_list = self.image_list[:max_items]\n",
220
- " self.image_size = image_size\n",
221
- " \n",
222
- " def __len__(self):\n",
223
- " return len(self.image_list)\n",
224
- " \n",
225
- " def _get_raw_image(self, i):\n",
226
- " image_name = self.image_list.iloc[0].key\n",
227
- " image_path = (self.images_root/image_name[0:3]/image_name[3:6]/image_name).with_suffix('.jpg')\n",
228
- " return default_loader(image_path)\n",
229
- " \n",
230
- " def resize_image(self, image):\n",
231
- " s = min(image.size)\n",
232
- " r = self.image_size / s\n",
233
- " s = (round(r * image.size[1]), round(r * image.size[0]))\n",
234
- " image = TF.resize(image, s, interpolation=InterpolationMode.LANCZOS)\n",
235
- " image = TF.center_crop(image, output_size = 2 * [self.image_size])\n",
236
- " # FIXME: np.array is necessary in my installation, but it should be automatic\n",
237
- " image = torch.unsqueeze(T.ToTensor()(np.array(image)), 0)\n",
238
- " image = image.permute(0, 2, 3, 1).numpy()\n",
239
- " return image\n",
240
- " \n",
241
- " def __getitem__(self, i):\n",
242
- " image = self._get_raw_image(i)\n",
243
- " image = self.resize_image(image)\n",
244
- " # Just return the image, not the caption\n",
245
- " return image"
246
- ]
247
- },
248
- {
249
- "cell_type": "markdown",
250
- "id": "62ad01c3",
251
- "metadata": {},
252
- "source": [
253
- "## Encoding"
254
- ]
255
- },
256
- {
257
- "cell_type": "code",
258
- "execution_count": 10,
259
- "id": "88f36d0b",
260
- "metadata": {},
261
- "outputs": [],
262
- "source": [
263
- "def encode(model, batch):\n",
264
- " print(\"jitting encode function\")\n",
265
- " _, indices = model.encode(batch)\n",
266
- "\n",
267
- "# # FIXME: The model does not run in my computer (no cudNN currently installed) - faking it\n",
268
- "# indices = np.random.randint(0, 16384, (batch.shape[0], 256))\n",
269
- " return indices"
270
- ]
271
- },
272
- {
273
- "cell_type": "code",
274
- "execution_count": null,
275
- "id": "d1f45dd8",
276
- "metadata": {},
277
- "outputs": [],
278
- "source": [
279
- "#FIXME\n",
280
- "# import random\n",
281
- "# model = {}"
282
- ]
283
- },
284
- {
285
- "cell_type": "code",
286
- "execution_count": 11,
287
- "id": "1f35f0cb",
288
- "metadata": {},
289
- "outputs": [],
290
- "source": [
291
- "from flax.training.common_utils import shard\n",
292
- "\n",
293
- "def superbatch_generator(dataloader):\n",
294
- " iter_loader = iter(dataloader)\n",
295
- " for batch in iter_loader:\n",
296
- " batch = batch.squeeze(1)\n",
297
- " # Skip incomplete last batch\n",
298
- " if batch.shape[0] == dataloader.batch_size:\n",
299
- " yield shard(batch)"
300
- ]
301
- },
302
- {
303
- "cell_type": "code",
304
- "execution_count": 13,
305
- "id": "2210705b",
306
- "metadata": {},
307
- "outputs": [],
308
- "source": [
309
- "import os\n",
310
- "import jax\n",
311
- "\n",
312
- "def encode_captioned_dataset(dataset, output_jsonl, batch_size=32, num_workers=16):\n",
313
- " if os.path.isfile(output_jsonl):\n",
314
- " print(f\"Destination file {output_jsonl} already exists, please move away.\")\n",
315
- " return\n",
316
- " \n",
317
- " num_tpus = jax.device_count()\n",
318
- " dataloader = DataLoader(dataset, batch_size=num_tpus*batch_size, num_workers=num_workers)\n",
319
- " superbatches = superbatch_generator(dataloader)\n",
320
- " \n",
321
- " p_encoder = pmap(lambda batch: encode(model, batch))\n",
322
- "\n",
323
- " # We save each superbatch to avoid reallocation of buffers as we process them.\n",
324
- " # We keep the file open to prevent excessive file seeks.\n",
325
- " with open(output_jsonl, \"w\") as file:\n",
326
- " iterations = len(dataset) // (batch_size * num_tpus)\n",
327
- " for n in tqdm(range(iterations)):\n",
328
- " superbatch = next(superbatches)\n",
329
- " encoded = p_encoder(superbatch.numpy())\n",
330
- " encoded = encoded.reshape(-1, encoded.shape[-1])\n",
331
- "\n",
332
- " # Extract fields from the dataset internal `image_list` property, and save to disk\n",
333
- " # We need to read from the df because the Dataset only returns images\n",
334
- " start_index = n * batch_size * num_tpus\n",
335
- " end_index = (n+1) * batch_size * num_tpus\n",
336
- " keys = dataset.image_list[\"key\"][start_index:end_index].values\n",
337
- " captions = dataset.image_list[\"caption\"][start_index:end_index].values\n",
338
- "# encoded_as_string = list(map(lambda item: np.array2string(item, separator=',', max_line_width=50000, formatter={'int':lambda x: str(x)}), encoded))\n",
339
- " batch_df = pd.DataFrame.from_dict({\"key\": keys, \"caption\": captions, \"encoding\": encoded})\n",
340
- " batch_df.to_json(file, orient='records', lines=True)"
341
- ]
342
- },
343
- {
344
- "cell_type": "code",
345
- "execution_count": 14,
346
- "id": "7704863d",
347
- "metadata": {},
348
- "outputs": [
349
- {
350
- "name": "stdout",
351
- "output_type": "stream",
352
- "text": [
353
- "Processing /sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_04\n",
354
- "54024 selected from 500000 total entries\n"
355
- ]
356
- },
357
- {
358
- "name": "stderr",
359
- "output_type": "stream",
360
- "text": [
361
- "INFO:absl:Starting the local TPU driver.\n",
362
- "INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://\n",
363
- "INFO:absl:Unable to initialize backend 'tpu': Invalid argument: TpuPlatform is not available.\n",
364
- " 0%| | 0/31 [00:00<?, ?it/s]"
365
- ]
366
- },
367
- {
368
- "name": "stdout",
369
- "output_type": "stream",
370
- "text": [
371
- "jitting encode function\n"
372
- ]
373
- },
374
- {
375
- "name": "stderr",
376
- "output_type": "stream",
377
- "text": [
378
- "100%|███████████████████████████████████████████████████████████████████████████████| 31/31 [00:02<00:00, 10.61it/s]\n"
379
- ]
380
- },
381
- {
382
- "name": "stdout",
383
- "output_type": "stream",
384
- "text": [
385
- "Processing /sddata/dalle-mini/YFCC100M_OpenAI_subset/metadata_splitted/metadata_split_25\n",
386
- "99530 selected from 500000 total entries\n"
387
- ]
388
- },
389
- {
390
- "name": "stderr",
391
- "output_type": "stream",
392
- "text": [
393
- " 3%|██▌ | 1/31 [00:01<00:53, 1.79s/it]"
394
- ]
395
- },
396
- {
397
- "name": "stdout",
398
- "output_type": "stream",
399
- "text": [
400
- "jitting encode function\n"
401
- ]
402
- },
403
- {
404
- "name": "stderr",
405
- "output_type": "stream",
406
- "text": [
407
- "100%|███████████████████████████████████████████████████████████████████████████████| 31/31 [00:03<00:00, 9.92it/s]\n"
408
- ]
409
- }
410
- ],
411
- "source": [
412
- "for split in all_splits:\n",
413
- " print(f\"Processing {split}\")\n",
414
- " df = pd.read_json(split, orient=\"records\", lines=True)\n",
415
- " df['image_exists'] = df.apply(lambda row: image_exists(yfcc100m_images, row['key'], '.' + row['ext']), axis=1)\n",
416
- " print(f\"{len(df[df.image_exists])} selected from {len(df)} total entries\")\n",
417
- " \n",
418
- " df = df[df.image_exists]\n",
419
- " captions = df.apply(lambda row: ' '.join([row[\"title_clean\"], row[\"description_clean\"]]), axis=1)\n",
420
- " df[\"caption\"] = captions.values\n",
421
- " \n",
422
- " dataset = YFC100Dataset(\n",
423
- " image_list = df,\n",
424
- " images_root = yfcc100m_images,\n",
425
- " image_size = 256,\n",
426
- "# max_items = 2000,\n",
427
- " )\n",
428
- " \n",
429
- " encode_captioned_dataset(dataset, yfcc100m_output/split.name, batch_size=64, num_workers=16)"
430
- ]
431
- },
432
- {
433
- "cell_type": "markdown",
434
- "id": "8953dd84",
435
- "metadata": {},
436
- "source": [
437
- "----"
438
- ]
439
- }
440
- ],
441
- "metadata": {
442
- "kernelspec": {
443
- "display_name": "Python 3 (ipykernel)",
444
- "language": "python",
445
- "name": "python3"
446
- },
447
- "language_info": {
448
- "codemirror_mode": {
449
- "name": "ipython",
450
- "version": 3
451
- },
452
- "file_extension": ".py",
453
- "mimetype": "text/x-python",
454
- "name": "python",
455
- "nbconvert_exporter": "python",
456
- "pygments_lexer": "ipython3",
457
- "version": "3.8.10"
458
- }
459
- },
460
- "nbformat": 4,
461
- "nbformat_minor": 5
462
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
encoding/vqgan-jax-encoding-yfcc100m.ipynb CHANGED
The diff for this file is too large to render. See raw diff