boris committed on
Commit
35fe578
·
1 Parent(s): ebac379

feat: improve inference demo

Browse files
tools/inference/inference_pipeline.ipynb CHANGED
@@ -41,10 +41,10 @@
41
  "outputs": [],
42
  "source": [
43
  "# Install required libraries\n",
44
- "!pip install -q transformers\n",
45
- "!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git\n",
46
- "!pip install -q git+https://github.com/borisdayma/dalle-mini.git\n",
47
- "!pip install -q wandb"
48
  ]
49
  },
50
  {
@@ -70,8 +70,8 @@
70
  "# Model references\n",
71
  "\n",
72
  "# dalle-mini\n",
73
- "DALLE_MODEL = \"dalle-mini/dalle-mini/model-3bqwu04f:latest\" # can be wandb artifact or 🤗 Hub or local folder\n",
74
- "DALLE_COMMIT_ID = None # used only with 🤗 hub\n",
75
  "\n",
76
  "# VQGAN model\n",
77
  "VQGAN_REPO = \"dalle-mini/vqgan_imagenet_f16_16384\"\n",
@@ -91,13 +91,20 @@
91
  "import jax\n",
92
  "import jax.numpy as jnp\n",
93
  "\n",
 
 
 
 
 
 
 
 
 
 
94
  "# type used for computation - use bfloat16 on TPU's\n",
95
  "dtype = jnp.bfloat16 if jax.local_device_count() == 8 else jnp.float32\n",
96
  "\n",
97
- "# TODO:\n",
98
- "# - we currently have an issue with model.generate() in bfloat16\n",
99
- "# - https://github.com/google/jax/pull/9089 should fix it\n",
100
- "# - remove below line and test on TPU with next release of JAX\n",
101
  "dtype = jnp.float32"
102
  ]
103
  },
@@ -115,35 +122,18 @@
115
  "outputs": [],
116
  "source": [
117
  "# Load models & tokenizer\n",
118
- "from dalle_mini.model import DalleBart\n",
119
  "from vqgan_jax.modeling_flax_vqgan import VQModel\n",
120
- "from transformers import AutoTokenizer, CLIPProcessor, FlaxCLIPModel\n",
121
  "import wandb\n",
122
  "\n",
123
  "# Load dalle-mini\n",
124
- "if \":\" in DALLE_MODEL:\n",
125
- " # wandb artifact\n",
126
- " artifact = wandb.Api().artifact(DALLE_MODEL)\n",
127
- " # we only download required files (no need for opt_state which is large)\n",
128
- " model_files = [\n",
129
- " \"config.json\",\n",
130
- " \"flax_model.msgpack\",\n",
131
- " \"merges.txt\",\n",
132
- " \"special_tokens_map.json\",\n",
133
- " \"tokenizer.json\",\n",
134
- " \"tokenizer_config.json\",\n",
135
- " \"vocab.json\",\n",
136
- " ]\n",
137
- " for f in model_files:\n",
138
- " artifact.get_path(f).download(\"model\")\n",
139
- " model = DalleBart.from_pretrained(\"model\", dtype=dtype, abstract_init=True)\n",
140
- " tokenizer = AutoTokenizer.from_pretrained(\"model\")\n",
141
- "else:\n",
142
- " # local folder or 🤗 Hub\n",
143
- " model = DalleBart.from_pretrained(\n",
144
- " DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=dtype, abstract_init=True\n",
145
- " )\n",
146
- " tokenizer = AutoTokenizer.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)\n",
147
  "\n",
148
  "# Load VQGAN\n",
149
  "vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)\n",
@@ -210,7 +200,8 @@
210
  " prng_key=key,\n",
211
  " params=params,\n",
212
  " top_k=top_k,\n",
213
- " top_p=top_p\n",
 
214
  " )\n",
215
  "\n",
216
  "\n",
@@ -233,7 +224,7 @@
233
  "id": "HmVN6IBwapBA"
234
  },
235
  "source": [
236
- "Keys are passed to the model on each device to generate unique inferences per device."
237
  ]
238
  },
239
  {
@@ -247,7 +238,7 @@
247
  "import random\n",
248
  "\n",
249
  "# create a random key\n",
250
- "seed = random.randint(0, 2 ** 32 - 1)\n",
251
  "key = jax.random.PRNGKey(seed)"
252
  ]
253
  },
@@ -299,7 +290,7 @@
299
  },
300
  "outputs": [],
301
  "source": [
302
- "prompt = \"a red T-shirt\""
303
  ]
304
  },
305
  {
@@ -316,27 +307,19 @@
316
  },
317
  {
318
  "cell_type": "markdown",
319
- "metadata": {
320
- "id": "iFVOyYboP0L-"
321
- },
322
  "source": [
323
- "We repeat the prompt on each device and tokenize it."
324
  ]
325
  },
326
  {
327
  "cell_type": "code",
328
  "execution_count": null,
329
- "metadata": {
330
- "id": "Rii_FJ7POw1y"
331
- },
332
  "outputs": [],
333
  "source": [
334
- "# repeat the prompt on each device\n",
335
- "repeated_prompts = [processed_prompt] * jax.device_count()\n",
336
- "\n",
337
- "# tokenize\n",
338
  "tokenized_prompt = tokenizer(\n",
339
- " repeated_prompts,\n",
340
  " return_tensors=\"jax\",\n",
341
  " padding=\"max_length\",\n",
342
  " truncation=True,\n",
@@ -360,24 +343,18 @@
360
  },
361
  {
362
  "cell_type": "markdown",
363
- "metadata": {
364
- "id": "2wiDtG3_SH2u"
365
- },
366
  "source": [
367
- "Finally we distribute the tokenized prompt onto the devices."
368
  ]
369
  },
370
  {
371
  "cell_type": "code",
372
  "execution_count": null,
373
- "metadata": {
374
- "id": "AImyrxHtR9TG"
375
- },
376
  "outputs": [],
377
  "source": [
378
- "from flax.training.common_utils import shard\n",
379
- "\n",
380
- "tokenized_prompt = shard(tokenized_prompt)"
381
  ]
382
  },
383
  {
@@ -455,6 +432,8 @@
455
  },
456
  "outputs": [],
457
  "source": [
 
 
458
  "# get clip scores\n",
459
  "clip_inputs = processor(\n",
460
  " text=[prompt] * jax.device_count(),\n",
 
41
  "outputs": [],
42
  "source": [
43
  "# Install required libraries\n",
44
+ "#!pip install -q transformers\n",
45
+ "#!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git\n",
46
+ "#!pip install -q git+https://github.com/borisdayma/dalle-mini.git\n",
47
+ "#!pip install -q wandb"
48
  ]
49
  },
50
  {
 
70
  "# Model references\n",
71
  "\n",
72
  "# dalle-mini\n",
73
+ "DALLE_MODEL = \"dalle-mini/dalle-mini/model-mehdx7dg:latest\" # can be wandb artifact or 🤗 Hub or local folder\n",
74
+ "DALLE_COMMIT_ID = None\n",
75
  "\n",
76
  "# VQGAN model\n",
77
  "VQGAN_REPO = \"dalle-mini/vqgan_imagenet_f16_16384\"\n",
 
91
  "import jax\n",
92
  "import jax.numpy as jnp\n",
93
  "\n",
94
+ "# check how many devices are available\n",
95
+ "jax.local_device_count()"
96
+ ]
97
+ },
98
+ {
99
+ "cell_type": "code",
100
+ "execution_count": null,
101
+ "metadata": {},
102
+ "outputs": [],
103
+ "source": [
104
  "# type used for computation - use bfloat16 on TPU's\n",
105
  "dtype = jnp.bfloat16 if jax.local_device_count() == 8 else jnp.float32\n",
106
  "\n",
107
+ "# TODO: model.generate() has an issue in bfloat16 - force float32 below until it is fixed\n",
 
 
 
108
  "dtype = jnp.float32"
109
  ]
110
  },
 
122
  "outputs": [],
123
  "source": [
124
  "# Load models & tokenizer\n",
125
+ "from dalle_mini.model import DalleBart, DalleBartTokenizer\n",
126
  "from vqgan_jax.modeling_flax_vqgan import VQModel\n",
127
+ "from transformers import CLIPProcessor, FlaxCLIPModel\n",
128
  "import wandb\n",
129
  "\n",
130
  "# Load dalle-mini\n",
131
+ "model = DalleBart.from_pretrained(\n",
132
+ " DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=dtype, abstract_init=True\n",
133
+ ")\n",
134
+ "tokenizer = DalleBartTokenizer.from_pretrained(\n",
135
+ " DALLE_MODEL, revision=DALLE_COMMIT_ID\n",
136
+ ")\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  "\n",
138
  "# Load VQGAN\n",
139
  "vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)\n",
 
200
  " prng_key=key,\n",
201
  " params=params,\n",
202
  " top_k=top_k,\n",
203
+ " top_p=top_p,\n",
204
+ " max_length=257\n",
205
  " )\n",
206
  "\n",
207
  "\n",
 
224
  "id": "HmVN6IBwapBA"
225
  },
226
  "source": [
227
+ "Keys are passed to the model on each device to generate a unique inference per device."
228
  ]
229
  },
230
  {
 
238
  "import random\n",
239
  "\n",
240
  "# create a random key\n",
241
+ "seed = random.randint(0, 2**32 - 1)\n",
242
  "key = jax.random.PRNGKey(seed)"
243
  ]
244
  },
 
290
  },
291
  "outputs": [],
292
  "source": [
293
+ "prompt = \"a waterfall under the sunset\""
294
  ]
295
  },
296
  {
 
307
  },
308
  {
309
  "cell_type": "markdown",
310
+ "metadata": {},
 
 
311
  "source": [
312
+ "We tokenize the prompt."
313
  ]
314
  },
315
  {
316
  "cell_type": "code",
317
  "execution_count": null,
318
+ "metadata": {},
 
 
319
  "outputs": [],
320
  "source": [
 
 
 
 
321
  "tokenized_prompt = tokenizer(\n",
322
+ " processed_prompt,\n",
323
  " return_tensors=\"jax\",\n",
324
  " padding=\"max_length\",\n",
325
  " truncation=True,\n",
 
343
  },
344
  {
345
  "cell_type": "markdown",
346
+ "metadata": {},
 
 
347
  "source": [
348
+ "Finally we replicate it onto each device."
349
  ]
350
  },
351
  {
352
  "cell_type": "code",
353
  "execution_count": null,
354
+ "metadata": {},
 
 
355
  "outputs": [],
356
  "source": [
357
+ "tokenized_prompt = replicate(tokenized_prompt)"
 
 
358
  ]
359
  },
360
  {
 
432
  },
433
  "outputs": [],
434
  "source": [
435
+ "from flax.training.common_utils import shard\n",
436
+ "\n",
437
  "# get clip scores\n",
438
  "clip_inputs = processor(\n",
439
  " text=[prompt] * jax.device_count(),\n",