Spaces:
Running
Running
feat(demo): uncomment pip install
Browse files
tools/inference/inference_pipeline.ipynb
CHANGED
@@ -30,7 +30,7 @@
|
|
30 |
},
|
31 |
{
|
32 |
"cell_type": "code",
|
33 |
-
"execution_count":
|
34 |
"metadata": {
|
35 |
"colab": {
|
36 |
"base_uri": "https://localhost:8080/"
|
@@ -41,10 +41,10 @@
|
|
41 |
"outputs": [],
|
42 |
"source": [
|
43 |
"# Install required libraries\n",
|
44 |
-
"
|
45 |
-
"
|
46 |
-
"
|
47 |
-
"
|
48 |
]
|
49 |
},
|
50 |
{
|
@@ -61,7 +61,7 @@
|
|
61 |
},
|
62 |
{
|
63 |
"cell_type": "code",
|
64 |
-
"execution_count":
|
65 |
"metadata": {
|
66 |
"id": "K6CxW2o42f-w"
|
67 |
},
|
@@ -84,9 +84,28 @@
|
|
84 |
},
|
85 |
{
|
86 |
"cell_type": "code",
|
87 |
-
"execution_count":
|
88 |
"metadata": {},
|
89 |
-
"outputs": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
90 |
"source": [
|
91 |
"import jax\n",
|
92 |
"import jax.numpy as jnp\n",
|
|
|
30 |
},
|
31 |
{
|
32 |
"cell_type": "code",
|
33 |
+
"execution_count": 1,
|
34 |
"metadata": {
|
35 |
"colab": {
|
36 |
"base_uri": "https://localhost:8080/"
|
|
|
41 |
"outputs": [],
|
42 |
"source": [
|
43 |
"# Install required libraries\n",
|
44 |
+
"!pip install -q transformers\n",
|
45 |
+
"!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git\n",
|
46 |
+
"!pip install -q git+https://github.com/borisdayma/dalle-mini.git\n",
|
47 |
+
"!pip install -q wandb"
|
48 |
]
|
49 |
},
|
50 |
{
|
|
|
61 |
},
|
62 |
{
|
63 |
"cell_type": "code",
|
64 |
+
"execution_count": 2,
|
65 |
"metadata": {
|
66 |
"id": "K6CxW2o42f-w"
|
67 |
},
|
|
|
84 |
},
|
85 |
{
|
86 |
"cell_type": "code",
|
87 |
+
"execution_count": 3,
|
88 |
"metadata": {},
|
89 |
+
"outputs": [
|
90 |
+
{
|
91 |
+
"ename": "KeyboardInterrupt",
|
92 |
+
"evalue": "",
|
93 |
+
"output_type": "error",
|
94 |
+
"traceback": [
|
95 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
96 |
+
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
|
97 |
+
"Input \u001b[0;32mIn [3]\u001b[0m, in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mjax\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mnumpy\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mjnp\u001b[39;00m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;66;03m# check how many devices are available\u001b[39;00m\n\u001b[0;32m----> 5\u001b[0m \u001b[43mjax\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlocal_device_count\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
|
98 |
+
"File \u001b[0;32m~/.pyenv/versions/3.9.7/envs/dev/lib/python3.9/site-packages/jax/_src/lib/xla_bridge.py:330\u001b[0m, in \u001b[0;36mlocal_device_count\u001b[0;34m(backend)\u001b[0m\n\u001b[1;32m 328\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mlocal_device_count\u001b[39m(backend: Optional[Union[\u001b[38;5;28mstr\u001b[39m, XlaBackend]] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28mint\u001b[39m:\n\u001b[1;32m 329\u001b[0m \u001b[38;5;124;03m\"\"\"Returns the number of devices addressable by this process.\"\"\"\u001b[39;00m\n\u001b[0;32m--> 330\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mint\u001b[39m(\u001b[43mget_backend\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbackend\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mlocal_device_count())\n",
|
99 |
+
"File \u001b[0;32m~/.pyenv/versions/3.9.7/envs/dev/lib/python3.9/site-packages/jax/_src/lib/xla_bridge.py:298\u001b[0m, in \u001b[0;36mget_backend\u001b[0;34m(platform)\u001b[0m\n\u001b[1;32m 296\u001b[0m \u001b[38;5;129m@lru_cache\u001b[39m(maxsize\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m) \u001b[38;5;66;03m# don't use util.memoize because there is no X64 dependence.\u001b[39;00m\n\u001b[1;32m 297\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mget_backend\u001b[39m(platform\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[0;32m--> 298\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_get_backend_uncached\u001b[49m\u001b[43m(\u001b[49m\u001b[43mplatform\u001b[49m\u001b[43m)\u001b[49m\n",
|
100 |
+
"File \u001b[0;32m~/.pyenv/versions/3.9.7/envs/dev/lib/python3.9/site-packages/jax/_src/lib/xla_bridge.py:281\u001b[0m, in \u001b[0;36m_get_backend_uncached\u001b[0;34m(platform)\u001b[0m\n\u001b[1;32m 278\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(platform, (\u001b[38;5;28mtype\u001b[39m(\u001b[38;5;28;01mNone\u001b[39;00m), \u001b[38;5;28mstr\u001b[39m)):\n\u001b[1;32m 279\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m platform\n\u001b[0;32m--> 281\u001b[0m bs \u001b[38;5;241m=\u001b[39m \u001b[43mbackends\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 282\u001b[0m platform \u001b[38;5;241m=\u001b[39m (platform \u001b[38;5;129;01mor\u001b[39;00m FLAGS\u001b[38;5;241m.\u001b[39mjax_xla_backend \u001b[38;5;129;01mor\u001b[39;00m FLAGS\u001b[38;5;241m.\u001b[39mjax_platform_name\n\u001b[1;32m 283\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[1;32m 284\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m platform \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
|
101 |
+
"File \u001b[0;32m~/.pyenv/versions/3.9.7/envs/dev/lib/python3.9/site-packages/jax/_src/lib/xla_bridge.py:231\u001b[0m, in \u001b[0;36mbackends\u001b[0;34m()\u001b[0m\n\u001b[1;32m 229\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m platform, priority \u001b[38;5;129;01min\u001b[39;00m platforms_and_priorites:\n\u001b[1;32m 230\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 231\u001b[0m backend \u001b[38;5;241m=\u001b[39m \u001b[43m_init_backend\u001b[49m\u001b[43m(\u001b[49m\u001b[43mplatform\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 232\u001b[0m _backends[platform] \u001b[38;5;241m=\u001b[39m backend\n\u001b[1;32m 233\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m priority \u001b[38;5;241m>\u001b[39m default_priority:\n",
|
102 |
+
"File \u001b[0;32m~/.pyenv/versions/3.9.7/envs/dev/lib/python3.9/site-packages/jax/_src/lib/xla_bridge.py:260\u001b[0m, in \u001b[0;36m_init_backend\u001b[0;34m(platform)\u001b[0m\n\u001b[1;32m 257\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mUnknown backend \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mplatform\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 259\u001b[0m logging\u001b[38;5;241m.\u001b[39mvlog(\u001b[38;5;241m1\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mInitializing backend \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m%\u001b[39m platform)\n\u001b[0;32m--> 260\u001b[0m backend \u001b[38;5;241m=\u001b[39m \u001b[43mfactory\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 261\u001b[0m \u001b[38;5;66;03m# TODO(skye): consider raising more descriptive errors directly from backend\u001b[39;00m\n\u001b[1;32m 262\u001b[0m \u001b[38;5;66;03m# factories instead of returning None.\u001b[39;00m\n\u001b[1;32m 263\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m backend \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
|
103 |
+
"File \u001b[0;32m~/.pyenv/versions/3.9.7/envs/dev/lib/python3.9/site-packages/jax/_src/lib/xla_bridge.py:170\u001b[0m, in \u001b[0;36mtpu_client_timer_callback\u001b[0;34m(timer_secs)\u001b[0m\n\u001b[1;32m 167\u001b[0m t\u001b[38;5;241m.\u001b[39mstart()\n\u001b[1;32m 169\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 170\u001b[0m client \u001b[38;5;241m=\u001b[39m \u001b[43mxla_client\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmake_tpu_client\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 171\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 172\u001b[0m t\u001b[38;5;241m.\u001b[39mcancel()\n",
|
104 |
+
"File \u001b[0;32m~/.pyenv/versions/3.9.7/envs/dev/lib/python3.9/site-packages/jaxlib/xla_client.py:96\u001b[0m, in \u001b[0;36mmake_tpu_client\u001b[0;34m()\u001b[0m\n\u001b[1;32m 95\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mmake_tpu_client\u001b[39m():\n\u001b[0;32m---> 96\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_xla\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_tpu_client\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmax_inflight_computations\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m32\u001b[39;49m\u001b[43m)\u001b[49m\n",
|
105 |
+
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
|
106 |
+
]
|
107 |
+
}
|
108 |
+
],
|
109 |
"source": [
|
110 |
"import jax\n",
|
111 |
"import jax.numpy as jnp\n",
|