Add an evaluation script
- run_medclip.sh +3 -3
- src/hybrid_clip/test_clip.ipynb +55 -47
run_medclip.sh
CHANGED
@@ -1,8 +1,8 @@
 python src/hybrid_clip/run_hybrid_clip.py \
-    --output_dir ./snapshots/
-    --text_model_name_or_path="
+    --output_dir ./snapshots/vision_augmented_biobert \
+    --text_model_name_or_path="allenai/scibert_scivocab_uncased" \
     --vision_model_name_or_path="openai/clip-vit-base-patch32" \
-    --tokenizer_name="
+    --tokenizer_name="allenai/scibert_scivocab_uncased" \
     --train_file="data/train_dataset.json" \
     --validation_file="data/valid_dataset.json" \
     --do_train --do_eval \
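As a quick reference, a minimal sketch (not part of the commit) of loading the checkpoint this run produces, assuming run_hybrid_clip.py leaves a from_pretrained-compatible snapshot in the new --output_dir; the notebook below instead loads the published flax-community/medclip-roco checkpoint:

from transformers import AutoTokenizer
from modeling_hybrid_clip import FlaxHybridCLIP

# Assumption: the training run saved a from_pretrained-compatible snapshot here.
model = FlaxHybridCLIP.from_pretrained("./snapshots/vision_augmented_biobert")
# The text tower is SciBERT, so its tokenizer prepares the evaluation inputs.
tokenizer = AutoTokenizer.from_pretrained("allenai/scibert_scivocab_uncased")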
src/hybrid_clip/test_clip.ipynb
CHANGED
@@ -26,77 +26,85 @@
 "cells": [
  {
   "cell_type": "code",
-   "execution_count":
-   "metadata": {},
-   "outputs": [],
+   "execution_count": 38,
   "source": [
+    "from transformers import AutoTokenizer\n",
+    "from configuration_hybrid_clip import HybridCLIPConfig\n",
    "from modeling_hybrid_clip import FlaxHybridCLIP\n",
-    "
+    " \n",
+    "# config = HybridCLIPConfig.from_pretrained(\"../..\")\n",
+    "model = FlaxHybridCLIP.from_pretrained(\"flax-community/medclip-roco\")"
+   ],
+   "outputs": [],
+   "metadata": {}
  },
  {
   "cell_type": "code",
-   "execution_count":
-   "metadata": {},
-   "outputs": [
-    {
-     "output_type": "stream",
-     "name": "stderr",
-     "text": [
-      "INFO:absl:Starting the local TPU driver.\n",
-      "INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://\n",
-      "INFO:absl:Unable to initialize backend 'tpu': Invalid argument: TpuPlatform is not available.\n"
-     ]
-    }
-   ],
+   "execution_count": 39,
   "source": [
-    "
-   ]
+    "tokenizer = AutoTokenizer.from_pretrained(\"allenai/scibert_scivocab_uncased\")"
+   ],
+   "outputs": [],
+   "metadata": {}
  },
  {
   "cell_type": "code",
-   "execution_count":
-   "
+   "execution_count": 31,
+   "source": [
+    "import torch\n",
+    "import numpy as np\n",
+    "from run_hybrid_clip import Transform\n",
+    "from torchvision.transforms.functional import InterpolationMode\n",
+    "\n",
+    "image_size = 224\n",
+    "transforms = Transform(image_size)\n"
+   ],
   "outputs": [],
+   "metadata": {}
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 40,
   "source": [
-    "
+    "from torchvision.io.image import read_image, ImageReadMode\n",
+    "\n",
+    "# TODO create a batch of images\n",
+    "img = read_image('../../data/PMC4917066_amjcaserep-17-301-g001.jpg', mode=ImageReadMode.RGB)\n",
+    "tr_img = transforms(img)\n"
+   ],
+   "outputs": [],
+   "metadata": {}
  },
  {
   "cell_type": "code",
-   "execution_count":
-   "
+   "execution_count": 37,
+   "source": [
+    "max_seq_length = 128\n",
+    "pixel_values = torch.stack([example[0] for example in examples]).permute(0, 2, 3, 1).numpy()\n",
+    "# pixel_values = torch.stack([example[0] for example in examples]).numpy()\n",
+    "captions = [example[1] for example in examples]\n",
+    "inputs = tokenizer(captions, max_length=max_seq_length, padding=\"max_length\", return_tensors=\"np\",\n",
+    "                   truncation=True)\n",
+    "batch = {\n",
+    "    \"pixel_values\": pixel_values,\n",
+    "    \"input_ids\": inputs[\"input_ids\"],\n",
+    "    \"attention_mask\": inputs[\"attention_mask\"],\n",
+    "    }\n",
+    "logits = model(**batch, train=False)[0]"
+   ],
   "outputs": [
-    {
-     "output_type": "stream",
-     "name": "stderr",
-     "text": [
-      "INFO:absl:Starting the local TPU driver.\n",
-      "INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://\n",
-      "INFO:absl:Unable to initialize backend 'tpu': Invalid argument: TpuPlatform is not available.\n"
-     ]
-    },
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
-      "[
+      "torch.Size([3, 224, 224])"
      ]
     },
     "metadata": {},
-    "execution_count":
+    "execution_count": 37
    }
   ],
-   "
-   "jax.devices()"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
+   "metadata": {}
  }
 ]
}
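Read end to end, the new cells perform one evaluation step: load the model and tokenizer, transform an image, tokenize captions, and run the batch through the Flax model. The sketch below condenses them into a single script; the examples list of (image, caption) pairs is an assumption, since the committed cells reference examples without defining it, and the caption string is a placeholder.

import torch
from torchvision.io.image import read_image, ImageReadMode
from transformers import AutoTokenizer

from modeling_hybrid_clip import FlaxHybridCLIP
from run_hybrid_clip import Transform

model = FlaxHybridCLIP.from_pretrained("flax-community/medclip-roco")
tokenizer = AutoTokenizer.from_pretrained("allenai/scibert_scivocab_uncased")
transforms = Transform(224)

# Assumption: the notebook never defines `examples`; build one (image, caption) pair by hand.
img = read_image("../../data/PMC4917066_amjcaserep-17-301-g001.jpg", mode=ImageReadMode.RGB)
examples = [(transforms(img), "placeholder ROCO-style caption")]

# The Flax model expects channels-last images, hence the permute before converting to numpy.
pixel_values = torch.stack([ex[0] for ex in examples]).permute(0, 2, 3, 1).numpy()
inputs = tokenizer([ex[1] for ex in examples], max_length=128,
                   padding="max_length", truncation=True, return_tensors="np")
batch = {
    "pixel_values": pixel_values,
    "input_ids": inputs["input_ids"],
    "attention_mask": inputs["attention_mask"],
}
# First element of the model output, as used in the notebook cell.
logits = model(**batch, train=False)[0]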