File size: 12,675 Bytes
6047b49
 
 
 
 
 
 
3b508e3
6047b49
 
 
 
 
 
 
3b508e3
 
 
6047b49
 
 
 
3b508e3
6047b49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3b508e3
 
 
 
 
 
 
 
6047b49
 
 
3b508e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6047b49
 
 
 
 
3b508e3
 
 
 
 
 
 
 
 
 
6047b49
 
 
3b508e3
6047b49
 
3b508e3
6047b49
 
 
 
3b508e3
6047b49
 
 
 
3b508e3
6047b49
 
 
3b508e3
 
 
 
 
 
 
 
6047b49
 
3b508e3
6047b49
 
 
 
 
 
 
 
 
 
 
 
 
 
3b508e3
 
 
 
 
 
 
 
 
6047b49
3b508e3
 
6047b49
 
 
 
3b508e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6047b49
 
 
 
3b508e3
6047b49
 
3b508e3
 
 
 
 
 
 
 
 
 
 
6047b49
 
 
 
 
 
3b508e3
6047b49
 
 
 
 
3b508e3
 
 
6047b49
 
 
 
3b508e3
 
 
6047b49
 
 
 
 
3b508e3
6047b49
 
 
 
3b508e3
6047b49
 
3b508e3
6047b49
 
 
 
3b508e3
6047b49
 
3b508e3
6047b49
 
 
 
3b508e3
 
 
 
 
 
 
 
 
 
6047b49
 
 
 
 
 
 
 
 
 
3b508e3
6047b49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3b508e3
6047b49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3b508e3
6047b49
 
 
 
3b508e3
6047b49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3b508e3
6047b49
 
 
 
 
 
 
 
 
 
3b508e3
6047b49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3b508e3
6047b49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3b508e3
6047b49
 
 
 
 
 
 
 
 
 
 
3b508e3
6047b49
 
 
 
 
 
 
 
 
 
 
 
 
3b508e3
6047b49
 
 
 
 
 
 
 
 
 
 
 
3b508e3
6047b49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d0b72877",
   "metadata": {},
   "source": [
    "# VQGAN JAX Encoding for `webdataset`"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ba7b31e6",
   "metadata": {},
   "source": [
    "This notebook shows how to pre-encode images to token sequences using JAX, VQGAN and a dataset in the [`webdataset` format](https://webdataset.github.io/webdataset/).\n",
    "\n",
    "This example uses a small subset of YFCC100M we created for testing, but it should be easy to adapt to any other image/caption dataset in the `webdataset` format."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b59489e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "\n",
    "import torch\n",
    "import torchvision.transforms as T\n",
    "import torchvision.transforms.functional as TF\n",
    "from torchvision.transforms import InterpolationMode\n",
    "import math\n",
    "\n",
    "import webdataset as wds\n",
    "\n",
    "import jax\n",
    "from jax import pmap"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c7c4c1e6",
   "metadata": {},
   "source": [
    "## Dataset and Parameters"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9822850f",
   "metadata": {},
   "source": [
    "The following is the list of shards we'll process. We hardcode the total number of samples so that `tqdm` can show progress bars with a known total."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1265dbfe",
   "metadata": {},
   "outputs": [],
   "source": [
    "# `{0000..0008}` is brace notation that webdataset expands to shards 0000 through 0008.\n",
    "shards = 'https://huggingface.co/datasets/dalle-mini/YFCC100M_OpenAI_subset/resolve/main/data/shard-{0000..0008}.tar'\n",
    "# Total number of samples in these shards; only used to size the progress bars / `len` hint.\n",
    "length = 8320"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7e38fa14",
   "metadata": {},
   "source": [
    "If we are extra cautious or our server is unreliable, we can enable retries by providing a custom `curl` retrieval command:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c8c5960",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Enable curl retries to try to work around temporary network / server errors.\n",
    "# This shouldn't be necessary when using reliable servers.\n",
    "# shards = f'pipe:curl -s --retry 5 --retry-delay 5 -L {shards} || true'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13c6631b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "# Directory where the encoded .jsonl splits are written\n",
    "encoded_output = Path.home().joinpath('data', 'wds', 'encoded')\n",
    "\n",
    "# Images per device in each batch\n",
    "batch_size = 128\n",
    "# DataLoader worker processes that read and preprocess data in parallel\n",
    "num_workers = 8"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3435fb85",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Global batch: `batch_size` images for each device (use a smaller size while testing).\n",
    "bs = jax.device_count() * batch_size\n",
    "# Number of batches covering all `length` samples; the last one may be partial.\n",
    "batches = math.ceil(length / bs)"
   ]
  },
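  {
   "cell_type": "markdown",
   "id": "88598e4b",
   "metadata": {},
   "source": [
    "### Image processing\n",
    "\n",
    "Each image is resized so that its shorter side is 256 pixels, center-cropped to 256x256, and converted to a channels-last tensor."
   ]
  },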
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "669b35df",
   "metadata": {},
   "outputs": [],
   "source": [
    "def center_crop(image, max_size=256):\n",
    "    \"\"\"Resize `image` so its shorter side is `max_size`, then crop the central `max_size` square.\n",
    "\n",
    "    Note: small images are upscaled as well; ideally they should be filtered out beforehand.\n",
    "    \"\"\"\n",
    "    resized = TF.resize(image, max_size, interpolation=InterpolationMode.LANCZOS)\n",
    "    return TF.center_crop(resized, output_size=[max_size, max_size])\n",
    "\n",
    "preprocess_image = T.Compose([\n",
    "    center_crop,\n",
    "    T.ToTensor(),                            # PIL image -> (channels, height, width) tensor\n",
    "    lambda tensor: tensor.permute(1, 2, 0),  # move channels last, as the encoder expects\n",
    "])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a185e90c",
   "metadata": {},
   "source": [
    "### Caption preparation\n",
    "\n",
    "Note that `create_caption` receives the decoded contents of each sample's `json` field, and the string it returns replaces that field.\n",
    "If we want to keep the other fields inside `json`, we can instead add `caption` as a new key and return the whole dictionary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "423ee10e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_caption(item):\n",
    "    \"\"\"Build a caption from the decoded `json` metadata of a sample.\n",
    "\n",
    "    The stripped `title_clean` and `description_clean` fields are joined with a space,\n",
    "    adding a period to the title unless it already ends with `.`, `!` or `?`.\n",
    "    Empty parts are skipped, so the caption never has a stray leading or trailing space.\n",
    "    \"\"\"\n",
    "    title = item['title_clean'].strip()\n",
    "    description = item['description_clean'].strip()\n",
    "    if title and title[-1] not in '.!?':\n",
    "        title += '.'\n",
    "    return ' '.join(part for part in (title, description) if part)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8d3a95db",
   "metadata": {},
   "source": [
    "When an error occurs (a download is interrupted, an image cannot be decoded, etc.), the process stops with an exception. We can use one of the exception handlers provided by the `webdataset` library, such as `wds.warn_and_continue` or `wds.ignore_and_continue`, to skip the offending entry and keep iterating.\n",
    "\n",
    "**IMPORTANT WARNING:** Do not use error handlers to ignore exceptions until you have tested that your processing pipeline works correctly. Otherwise, if every entry fails (for example, because of a bug in a processing function), the process will keep looking for a valid entry and consume your whole dataset without doing any work.\n",
    "\n",
    "We can also write a custom exception handler, as demonstrated here:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "369d9719",
   "metadata": {},
   "outputs": [],
   "source": [
    "# UNUSED - Log exceptions to a file\n",
    "def ignore_and_log(exn):\n",
    "    \"\"\"Append `repr(exn)` to errors.txt and return True so webdataset skips the sample.\"\"\"\n",
    "    with open('errors.txt', 'a') as log_file:\n",
    "        print(repr(exn), file=log_file)\n",
    "    return True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27de1414",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Warn about each failing sample and skip it. Alternatives: `wds.ignore_and_continue`\n",
    "# (skip silently) or the `ignore_and_log` handler defined above (append to errors.txt).\n",
    "exception_handler = wds.warn_and_continue"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5149b6d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep the shards in their original order (no shuffling) so the encoded files are easy to\n",
    "# match against the source data; set `shardshuffle=True` for training instead.\n",
    "# Once the pipeline is known to work, `handler=exception_handler` can also be passed to\n",
    "# `map_dict` to skip samples whose image or caption cannot be processed.\n",
    "dataset = (\n",
    "    wds.WebDataset(\n",
    "        shards,\n",
    "        length=batches,              # hint so that `len` is implemented (progress bars)\n",
    "        shardshuffle=False,\n",
    "        handler=exception_handler,   # skip read errors instead of failing\n",
    "    )\n",
    "    .decode('pil')                                        # decode images with PIL\n",
    "    .map_dict(jpg=preprocess_image, json=create_caption)  # preprocess image, build caption\n",
    "    .to_tuple('__key__', 'jpg', 'json')                   # keep only key (for reference), image, caption\n",
    "    .batched(bs)                                          # group into global batches of `bs` samples\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8cac98cb",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "# Fetch one batch to check the whole pipeline before encoding everything.\n",
    "keys, images, captions = next(iter(dataset))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd268fbf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Expected (bs, 256, 256, channels): `preprocess_image` crops to 256x256 and moves channels last\n",
    "images.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c24693c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the first image; `permute` restores the channels-first layout `ToPILImage` expects.\n",
    "T.ToPILImage()(images[0].permute(2, 0, 1))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "44d50a51",
   "metadata": {},
   "source": [
    "### Torch DataLoader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2df5e13",
   "metadata": {},
   "outputs": [],
   "source": [
    "# `dataset` already yields whole batches (see `.batched(bs)`), so automatic batching\n",
    "# is disabled with `batch_size=None`; `num_workers` processes load data in parallel.\n",
    "dl = torch.utils.data.DataLoader(\n",
    "    dataset,\n",
    "    num_workers=num_workers,\n",
    "    batch_size=None,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a354472b",
   "metadata": {},
   "source": [
    "## VQGAN-JAX model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2fcf01d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "from vqgan_jax.modeling_flax_vqgan import VQModel"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9daa636d",
   "metadata": {},
   "source": [
    "We'll use a VQGAN trained with Taming Transformers and converted to a JAX model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47a8b818",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Load the converted VQGAN checkpoint from the Hugging Face Hub. The name suggests a\n",
    "# downsampling factor of 16 and a 16384-entry codebook (assumption, not verified here).\n",
    "model = VQModel.from_pretrained(\"flax-community/vqgan_f16_16384\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "62ad01c3",
   "metadata": {},
   "source": [
    "## Encoding"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "20357f74",
   "metadata": {},
   "source": [
    "Encoding is simple: `shard` splits each \"superbatch\" across the local devices, and `pmap` runs the encoder on all of them in parallel. This is all it takes to create our encoding function, which will be compiled (jitted) on its first call."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6686b004",
   "metadata": {},
   "outputs": [],
   "source": [
    "from flax.training.common_utils import shard\n",
    "from functools import partial"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "322a4619",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _encode_indices(batch):\n",
    "    \"\"\"Encode one device's slice of the batch and return only the VQGAN token indices.\"\"\"\n",
    "    # NOTE(review): no params are passed, so `model.encode` presumably uses the weights\n",
    "    # held by `model`; explicitly `replicate`-ing them did not seem to make a difference.\n",
    "    _, indices = model.encode(batch)\n",
    "    return indices\n",
    "\n",
    "# Parallel over the leading (device) axis created by `shard`; compiled on the first call.\n",
    "encode = jax.pmap(_encode_indices, axis_name=\"batch\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "14375a41",
   "metadata": {},
   "source": [
    "### Encoding loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff6c10d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pandas as pd\n",
    "\n",
    "def encode_captioned_dataset(dataloader, output_dir, save_every=14):\n",
    "    \"\"\"Encode every image yielded by `dataloader` to VQGAN token indices and save them as JSON Lines.\n",
    "\n",
    "    Args:\n",
    "        dataloader: iterable of `(keys, images, captions)` batches, where `images` is a\n",
    "            torch tensor of channels-last images (see `preprocess_image`).\n",
    "        output_dir: `pathlib.Path` of the output directory; created if it does not exist.\n",
    "        save_every: number of batches written to each `split_XXXXX.jsonl` file (hex index)\n",
    "            before a new file is started.\n",
    "\n",
    "    Each output line is a record with the sample `key`, its `caption` and its `encoding`:\n",
    "    the token indices formatted by `np.array2string`, e.g. `[12,345,...]`.\n",
    "    \"\"\"\n",
    "    output_dir.mkdir(parents=True, exist_ok=True)\n",
    "    # `shard` splits the leading axis across this many devices.\n",
    "    n_devices = jax.local_device_count()\n",
    "\n",
    "    # Saving strategy:\n",
    "    # - Create a new file every `save_every` batches to prevent excessive file seeking.\n",
    "    # - Save each batch after processing.\n",
    "    # - Keep the file open until we are done with it, and always close the last one\n",
    "    #   (even if an exception interrupts the loop) so buffered records reach the disk.\n",
    "    file = None\n",
    "    try:\n",
    "        for n, (keys, images, captions) in enumerate(tqdm(dataloader)):\n",
    "            if n % save_every == 0:\n",
    "                if file is not None:\n",
    "                    file.close()\n",
    "                split_num = n // save_every\n",
    "                file = open(output_dir/f'split_{split_num:05x}.jsonl', 'w')\n",
    "\n",
    "            images = images.numpy()\n",
    "            n_items = images.shape[0]\n",
    "            # `shard` needs the batch to split evenly across devices, so zero-pad a batch whose\n",
    "            # size is not a multiple of the device count (e.g. the final partial batch) and\n",
    "            # drop the padded results after encoding.\n",
    "            n_pad = (-n_items) % n_devices\n",
    "            if n_pad:\n",
    "                padding = np.zeros((n_pad, *images.shape[1:]), dtype=images.dtype)\n",
    "                images = np.concatenate([images, padding])\n",
    "\n",
    "            encoded = encode(shard(images))\n",
    "            # Copy to host once, then merge the device axis back into the batch axis.\n",
    "            encoded = np.asarray(encoded)\n",
    "            encoded = encoded.reshape(-1, encoded.shape[-1])[:n_items]\n",
    "\n",
    "            encoded_as_string = [\n",
    "                np.array2string(item, separator=',', max_line_width=50000, formatter={'int': lambda x: str(x)})\n",
    "                for item in encoded\n",
    "            ]\n",
    "            batch_df = pd.DataFrame.from_dict({\"key\": keys, \"caption\": captions, \"encoding\": encoded_as_string})\n",
    "            records = batch_df.to_json(orient='records', lines=True)\n",
    "            # Some pandas versions do not end `lines=True` output with a newline; without one,\n",
    "            # the last record of this batch would share a line with the next batch's first.\n",
    "            if not records.endswith('\\n'):\n",
    "                records += '\\n'\n",
    "            file.write(records)\n",
    "    finally:\n",
    "        if file is not None:\n",
    "            file.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "09ff75a3",
   "metadata": {},
   "source": [
    "Create a new file every 318 iterations. This should produce splits of ~500 MB each, when using a total batch size of 1024."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96222bb4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Batches per output file (~500 MB splits at a global batch size of 1024, as noted above).\n",
    "save_every = 318"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7704863d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encode the whole dataset, writing split_XXXXX.jsonl files (hex index) into `encoded_output`.\n",
    "encode_captioned_dataset(dl, encoded_output, save_every=save_every)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8953dd84",
   "metadata": {},
   "source": [
    "----"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "db471c52d602b4f5f40ecaf278e88ccfef85c29d0a1a07185b0d51fc7acf4e26"
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}