feat: create a table
dev/inference/wandb-backend.ipynb
CHANGED
@@ -46,7 +46,8 @@
 "batch_size = 8\n",
 "num_images = 128\n",
 "top_k = 8\n",
-"text_normalizer = TextNormalizer() if normalize_text else None"
+"text_normalizer = TextNormalizer() if normalize_text else None\n",
+"padding_item = 'NONE'"
 ]
 },
 {
@@ -95,8 +96,8 @@
 " samples = []\n",
 " for row in reader:\n",
 " samples.append(row)\n",
-" # make list multiple of batch_size by adding \
-" samples_to_add = [{'Caption':
+" # make list multiple of batch_size by adding elements\n",
+" samples_to_add = [{'Caption':padding_item, 'Theme':padding_item}] * (-len(samples) % batch_size)\n",
 " samples.extend(samples_to_add)\n",
 " # reshape\n",
 " samples = [samples[i:i+batch_size] for i in range(0, len(samples), batch_size)]"
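A rough sketch of what the padding change above does, outside the notebook. The values of batch_size and padding_item are taken from this diff; the example rows are made up:

batch_size = 8
padding_item = 'NONE'

# 13 made-up rows standing in for the CSV contents
samples = [{'Caption': f'caption {i}', 'Theme': f'theme {i}'} for i in range(13)]

# -len(samples) % batch_size is the number of rows missing to reach the next multiple of batch_size
samples_to_add = [{'Caption': padding_item, 'Theme': padding_item}] * (-len(samples) % batch_size)
samples.extend(samples_to_add)
assert len(samples) % batch_size == 0   # 16 rows: 13 real + 3 padding

# reshape into batches, as in the following lines of the cell
samples = [samples[i:i + batch_size] for i in range(0, len(samples), batch_size)]
assert len(samples) == 2 and all(len(b) == batch_size for b in samples)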
@@ -388,7 +389,6 @@
 " def p_clip(inputs):\n",
 " logits = clip(**inputs).logits_per_image\n",
 " return logits\n",
-" scores = jax.nn.softmax(logits, axis=0).squeeze() \n",
 " \n",
 " functions_pmapped = False"
 ]
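The removed line computed softmax scores from the CLIP logits. Since softmax is monotonic along the axis it is applied to, ranking images by raw logits picks the same top-k indices as ranking by softmax scores, which is presumably why the raw logits can be carried forward instead. A tiny check of that claim, with made-up values:

import jax
import jax.numpy as jnp

logits = jnp.array([2.0, -1.0, 0.5, 3.0])     # made-up per-image logits for a single prompt
scores = jax.nn.softmax(logits, axis=0)       # what the removed line computed

top_k = 2
top_by_logits = jnp.argsort(logits)[-top_k:]
top_by_scores = jnp.argsort(scores)[-top_k:]
assert (top_by_logits == top_by_scores).all()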
@@ -649,7 +649,8 @@
 "outputs": [],
 "source": [
 "results = []\n",
-"columns = ['Caption', 'Theme'] + [f'Image {i+1}' for i in range(top_k)] + [f'Score {i+1}' for i in range(top_k)]"
+"columns = ['Caption', 'Theme'] + [f'Image {i+1}' for i in range(top_k)] + [f'Score {i+1}' for i in range(top_k)]\n",
+"logits = jax.device_get(logits)"
 ]
 },
 {
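The added jax.device_get call copies the logits from device memory back to the host as a NumPy array, so they can be indexed into plain Python lists when building the table rows. A minimal illustration; the array shape here is made up:

import jax
import jax.numpy as jnp
import numpy as np

logits = jnp.ones((8, 128))          # made-up shape: one row per prompt in the batch
logits = jax.device_get(logits)      # host copy
assert isinstance(logits, np.ndarray)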
@@ -660,12 +661,23 @@
 "outputs": [],
 "source": [
 "for i, (idx, scores, sample) in enumerate(zip(top_idx, logits, batch)):\n",
+" if sample['Caption'] == padding_item: continue\n",
 " cur_images = [images[x] for x in images_per_prompt_indices + i]\n",
 " top_images = [wandb.Image(cur_images[x]) for x in idx]\n",
-" top_scores = [
+" top_scores = [scores[x] for x in idx]\n",
 " results.append([sample['Caption'], sample['Theme']] + top_images + top_scores)"
 ]
 },
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "4bf40461-99d3-4d36-b7cc-e0129a3c9053",
+"metadata": {},
+"outputs": [],
+"source": [
+"table = wandb.Table(columns=columns, data=results)"
+]
+},
 {
 "cell_type": "code",
 "execution_count": null,
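Putting the new pieces together, a self-contained sketch of how the table cell ends up being built. The rows and images here are fabricated, and the commented wandb.log call is only an assumption about how the table would be used; it is not part of this diff:

import numpy as np
import wandb

top_k = 2
columns = ['Caption', 'Theme'] + [f'Image {i+1}' for i in range(top_k)] + [f'Score {i+1}' for i in range(top_k)]

results = []
placeholder = np.zeros((64, 64, 3), dtype=np.uint8)            # stand-in for a generated image
top_images = [wandb.Image(placeholder) for _ in range(top_k)]
top_scores = [31.2, 28.7]                                      # made-up raw CLIP logits
results.append(['a red bicycle', 'vehicles'] + top_images + top_scores)

table = wandb.Table(columns=columns, data=results)
# wandb.log({'results': table})   # assumes an active wandb run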