Pedro Cuenca commited on
Commit
16f038a
·
1 Parent(s): 150ed18

* Notebook that processes CC12M and creates a version with encodings.

Browse files

The VQGAN in use was created by Boris Dayma:
https://huggingface.co/flax-community/vqgan_f16_16384. It was trained on
GPU using the Taming Transformers code and then converted to JAX.

The output file contains the following fields:
- `image_file`: relative path to the image file. To be preprended with
the root path where images reside.
- `caption`: the untransformed text caption.
- `encoding`: the encoding indices produced by the VQGAN, as a string
representation of a list with 256 integers.

encoding/vqgan-jax-encoding-with-captions.ipynb ADDED
@@ -0,0 +1,363 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "d0b72877",
6
+ "metadata": {},
7
+ "source": [
8
+ "# vqgan-jax-encoding-with-captions"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "markdown",
13
+ "id": "875c82b3",
14
+ "metadata": {},
15
+ "source": [
16
+ "Notebook based on [vqgan-jax-reconstruction](https://colab.research.google.com/drive/1mdXXsMbV6K_LTvCh3IImRsFIWcKU5m1w?usp=sharing) by @surajpatil.\n",
17
+ "\n",
18
+ "We process a `tsv` file with `image_file` and `caption` fields, and add a `vqgan_indices` column with indices extracted from a VQGAN-JAX model."
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "code",
23
+ "execution_count": 1,
24
+ "id": "3b59489e",
25
+ "metadata": {},
26
+ "outputs": [],
27
+ "source": [
28
+ "import io\n",
29
+ "\n",
30
+ "import requests\n",
31
+ "from PIL import Image\n",
32
+ "import numpy as np\n",
33
+ "from tqdm import tqdm\n",
34
+ "\n",
35
+ "import torch\n",
36
+ "import torchvision.transforms as T\n",
37
+ "import torchvision.transforms.functional as TF\n",
38
+ "from torchvision.transforms import InterpolationMode\n",
39
+ "from torch.utils.data import Dataset, DataLoader\n",
40
+ "\n",
41
+ "import jax\n",
42
+ "from jax import pmap"
43
+ ]
44
+ },
45
+ {
46
+ "cell_type": "markdown",
47
+ "id": "511c3b9e",
48
+ "metadata": {},
49
+ "source": [
50
+ "## VQGAN-JAX model"
51
+ ]
52
+ },
53
+ {
54
+ "cell_type": "markdown",
55
+ "id": "bb408f6c",
56
+ "metadata": {},
57
+ "source": [
58
+ "`dalle_mini` is a local package that contains the VQGAN-JAX model and other utilities."
59
+ ]
60
+ },
61
+ {
62
+ "cell_type": "code",
63
+ "execution_count": 2,
64
+ "id": "2ca50dc7",
65
+ "metadata": {},
66
+ "outputs": [],
67
+ "source": [
68
+ "from dalle_mini.vqgan_jax.modeling_flax_vqgan import VQModel"
69
+ ]
70
+ },
71
+ {
72
+ "cell_type": "markdown",
73
+ "id": "7b60da9a",
74
+ "metadata": {},
75
+ "source": [
76
+ "We'll use a VQGAN trained by using Taming Transformers and converted to a JAX model."
77
+ ]
78
+ },
79
+ {
80
+ "cell_type": "code",
81
+ "execution_count": 3,
82
+ "id": "29ce8b15",
83
+ "metadata": {},
84
+ "outputs": [
85
+ {
86
+ "data": {
87
+ "application/vnd.jupyter.widget-view+json": {
88
+ "model_id": "db406bdfc5d5428eaeae1631a04989dd",
89
+ "version_major": 2,
90
+ "version_minor": 0
91
+ },
92
+ "text/plain": [
93
+ "Downloading: 0%| | 0.00/433 [00:00<?, ?B/s]"
94
+ ]
95
+ },
96
+ "metadata": {},
97
+ "output_type": "display_data"
98
+ },
99
+ {
100
+ "data": {
101
+ "application/vnd.jupyter.widget-view+json": {
102
+ "model_id": "3e37f07fba6d48fca70313ae1fa8cc32",
103
+ "version_major": 2,
104
+ "version_minor": 0
105
+ },
106
+ "text/plain": [
107
+ "Downloading: 0%| | 0.00/304M [00:00<?, ?B/s]"
108
+ ]
109
+ },
110
+ "metadata": {},
111
+ "output_type": "display_data"
112
+ },
113
+ {
114
+ "name": "stderr",
115
+ "output_type": "stream",
116
+ "text": [
117
+ "INFO:absl:Starting the local TPU driver.\n",
118
+ "INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://\n",
119
+ "INFO:absl:Unable to initialize backend 'gpu': Not found: Could not find registered platform with name: \"cuda\". Available platform names are: Interpreter Host TPU\n"
120
+ ]
121
+ },
122
+ {
123
+ "name": "stdout",
124
+ "output_type": "stream",
125
+ "text": [
126
+ "Working with z of shape (1, 256, 16, 16) = 65536 dimensions.\n"
127
+ ]
128
+ }
129
+ ],
130
+ "source": [
131
+ "model = VQModel.from_pretrained(\"flax-community/vqgan_f16_16384\")"
132
+ ]
133
+ },
134
+ {
135
+ "cell_type": "markdown",
136
+ "id": "c7c4c1e6",
137
+ "metadata": {},
138
+ "source": [
139
+ "## Dataset"
140
+ ]
141
+ },
142
+ {
143
+ "cell_type": "markdown",
144
+ "id": "7014a7ce",
145
+ "metadata": {},
146
+ "source": [
147
+ "We use Luke Melas-Kyriazi's `dataset.py` which reads image paths and captions from a tsv file that contains both. We only need the images for encoding."
148
+ ]
149
+ },
150
+ {
151
+ "cell_type": "code",
152
+ "execution_count": 4,
153
+ "id": "85832702",
154
+ "metadata": {},
155
+ "outputs": [],
156
+ "source": [
157
+ "from dalle_mini.dataset import *"
158
+ ]
159
+ },
160
+ {
161
+ "cell_type": "code",
162
+ "execution_count": 5,
163
+ "id": "81b19eca",
164
+ "metadata": {},
165
+ "outputs": [],
166
+ "source": [
167
+ "cc12m_images = '/data/CC12M/images'\n",
168
+ "cc12m_list = '/data/CC12M/images-list-clean.tsv'\n",
169
+ "# cc12m_list = '/data/CC12M/images-10000.tsv'\n",
170
+ "cc12m_output = '/data/CC12M/images-encoded.tsv'"
171
+ ]
172
+ },
173
+ {
174
+ "cell_type": "code",
175
+ "execution_count": 6,
176
+ "id": "fecc9a00",
177
+ "metadata": {},
178
+ "outputs": [],
179
+ "source": [
180
+ "image_size = 256\n",
181
+ "def image_transform(image):\n",
182
+ " s = min(image.size)\n",
183
+ " r = image_size / s\n",
184
+ " s = (round(r * image.size[1]), round(r * image.size[0]))\n",
185
+ " image = TF.resize(image, s, interpolation=InterpolationMode.LANCZOS)\n",
186
+ " image = TF.center_crop(image, output_size = 2 * [image_size])\n",
187
+ " image = torch.unsqueeze(T.ToTensor()(image), 0)\n",
188
+ " image = image.permute(0, 2, 3, 1).numpy()\n",
189
+ " return image"
190
+ ]
191
+ },
192
+ {
193
+ "cell_type": "code",
194
+ "execution_count": 7,
195
+ "id": "4ce2211f",
196
+ "metadata": {},
197
+ "outputs": [],
198
+ "source": [
199
+ "dataset = CaptionDataset(\n",
200
+ " images_root=cc12m_images,\n",
201
+ " captions_path=cc12m_list,\n",
202
+ " image_transform=image_transform,\n",
203
+ " image_transform_type='torchvision',\n",
204
+ " include_captions=False\n",
205
+ ")"
206
+ ]
207
+ },
208
+ {
209
+ "cell_type": "code",
210
+ "execution_count": 8,
211
+ "id": "cc922704",
212
+ "metadata": {},
213
+ "outputs": [
214
+ {
215
+ "data": {
216
+ "text/plain": [
217
+ "8592141"
218
+ ]
219
+ },
220
+ "execution_count": 8,
221
+ "metadata": {},
222
+ "output_type": "execute_result"
223
+ }
224
+ ],
225
+ "source": [
226
+ "len(dataset)"
227
+ ]
228
+ },
229
+ {
230
+ "cell_type": "markdown",
231
+ "id": "62ad01c3",
232
+ "metadata": {},
233
+ "source": [
234
+ "## Encoding"
235
+ ]
236
+ },
237
+ {
238
+ "cell_type": "code",
239
+ "execution_count": 9,
240
+ "id": "88f36d0b",
241
+ "metadata": {},
242
+ "outputs": [],
243
+ "source": [
244
+ "def encode(model, batch):\n",
245
+ "# print(\"jitting encode function\")\n",
246
+ " _, indices = model.encode(batch)\n",
247
+ " return indices"
248
+ ]
249
+ },
250
+ {
251
+ "cell_type": "code",
252
+ "execution_count": 10,
253
+ "id": "1f35f0cb",
254
+ "metadata": {},
255
+ "outputs": [],
256
+ "source": [
257
+ "def superbatch_generator(dataloader, num_tpus):\n",
258
+ " iter_loader = iter(dataloader)\n",
259
+ " for batch in iter_loader:\n",
260
+ " superbatch = [batch.squeeze(1)]\n",
261
+ " try:\n",
262
+ " for b in range(num_tpus-1):\n",
263
+ " batch = next(iter_loader)\n",
264
+ " if batch is None:\n",
265
+ " break\n",
266
+ " # Skip incomplete last batch\n",
267
+ " if batch.shape[0] == dataloader.batch_size:\n",
268
+ " superbatch.append(batch.squeeze(1))\n",
269
+ " except StopIteration:\n",
270
+ " pass\n",
271
+ " superbatch = torch.stack(superbatch, axis=0)\n",
272
+ " yield superbatch"
273
+ ]
274
+ },
275
+ {
276
+ "cell_type": "code",
277
+ "execution_count": 11,
278
+ "id": "2210705b",
279
+ "metadata": {},
280
+ "outputs": [],
281
+ "source": [
282
+ "import os\n",
283
+ "\n",
284
+ "def encode_captioned_dataset(dataset, output_tsv, batch_size=32, num_workers=16):\n",
285
+ " if os.path.isfile(output_tsv):\n",
286
+ " print(f\"Destination file {output_tsv} already exists, please move away.\")\n",
287
+ " return\n",
288
+ " \n",
289
+ " num_tpus = 8 \n",
290
+ " dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)\n",
291
+ " superbatches = superbatch_generator(dataloader, num_tpus=num_tpus)\n",
292
+ " \n",
293
+ " p_encoder = pmap(lambda batch: encode(model, batch))\n",
294
+ "\n",
295
+ " # We save each superbatch to avoid reallocation of buffers as we process them.\n",
296
+ " # We keep the file open to prevent excessive file seeks.\n",
297
+ " with open(output_tsv, \"w\") as file:\n",
298
+ " iterations = len(dataset) // (batch_size * num_tpus)\n",
299
+ " for n in tqdm(range(iterations)):\n",
300
+ " superbatch = next(superbatches)\n",
301
+ " encoded = p_encoder(superbatch.numpy())\n",
302
+ " encoded = encoded.reshape(-1, encoded.shape[-1])\n",
303
+ "\n",
304
+ " # Extract fields from the dataset internal `captions` property, and save to disk\n",
305
+ " start_index = n * batch_size * num_tpus\n",
306
+ " end_index = (n+1) * batch_size * num_tpus\n",
307
+ " paths = dataset.captions[\"image_file\"][start_index:end_index].values\n",
308
+ " captions = dataset.captions[\"caption\"][start_index:end_index].values\n",
309
+ " encoded_as_string = list(map(lambda item: np.array2string(item, separator=',', max_line_width=50000, formatter={'int':lambda x: str(x)}), encoded))\n",
310
+ " batch_df = pd.DataFrame.from_dict({\"image_file\": paths, \"caption\": captions, \"encoding\": encoded_as_string})\n",
311
+ " batch_df.to_csv(file, sep='\\t', header=(n==0), index=None)\n",
312
+ " "
313
+ ]
314
+ },
315
+ {
316
+ "cell_type": "code",
317
+ "execution_count": null,
318
+ "id": "7704863d",
319
+ "metadata": {},
320
+ "outputs": [
321
+ {
322
+ "name": "stderr",
323
+ "output_type": "stream",
324
+ "text": [
325
+ " 4%|██▋ | 621/16781 [07:09<3:02:46, 1.47it/s]"
326
+ ]
327
+ }
328
+ ],
329
+ "source": [
330
+ "encode_captioned_dataset(dataset, cc12m_output, batch_size=64, num_workers=16)"
331
+ ]
332
+ },
333
+ {
334
+ "cell_type": "markdown",
335
+ "id": "8953dd84",
336
+ "metadata": {},
337
+ "source": [
338
+ "----"
339
+ ]
340
+ }
341
+ ],
342
+ "metadata": {
343
+ "kernelspec": {
344
+ "display_name": "Python 3 (ipykernel)",
345
+ "language": "python",
346
+ "name": "python3"
347
+ },
348
+ "language_info": {
349
+ "codemirror_mode": {
350
+ "name": "ipython",
351
+ "version": 3
352
+ },
353
+ "file_extension": ".py",
354
+ "mimetype": "text/x-python",
355
+ "name": "python",
356
+ "nbconvert_exporter": "python",
357
+ "pygments_lexer": "ipython3",
358
+ "version": "3.8.10"
359
+ }
360
+ },
361
+ "nbformat": 4,
362
+ "nbformat_minor": 5
363
+ }