kaushalya committed
Commit 1a914c5 · 1 Parent(s): a95325e

Add an evaluation script

Files changed (2)
  1. run_medclip.sh +3 -3
  2. src/hybrid_clip/test_clip.ipynb +55 -47
run_medclip.sh CHANGED
@@ -1,8 +1,8 @@
 python src/hybrid_clip/run_hybrid_clip.py \
-  --output_dir ./snapshots/vision_augmented \
-  --text_model_name_or_path="roberta-base" \
+  --output_dir ./snapshots/vision_augmented_biobert \
+  --text_model_name_or_path="allenai/scibert_scivocab_uncased" \
   --vision_model_name_or_path="openai/clip-vit-base-patch32" \
-  --tokenizer_name="roberta-base" \
+  --tokenizer_name="allenai/scibert_scivocab_uncased" \
   --train_file="data/train_dataset.json" \
   --validation_file="data/valid_dataset.json" \
   --do_train --do_eval \
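Note: the change above swaps the text encoder and tokenizer from roberta-base to allenai/scibert_scivocab_uncased and writes checkpoints to a new snapshot directory. A minimal sketch (not part of this commit) for sanity-checking that the new tokenizer resolves and handles a caption before launching training; the sample caption is a made-up placeholder:

from transformers import AutoTokenizer

# Same checkpoint name that run_medclip.sh now passes to --tokenizer_name.
tokenizer = AutoTokenizer.from_pretrained("allenai/scibert_scivocab_uncased")

# Placeholder caption, only to confirm tokenization works end to end.
sample = "Chest X-ray showing a left lower lobe consolidation."
encoded = tokenizer(sample, max_length=128, padding="max_length", truncation=True)
print(len(encoded["input_ids"]))  # 128 after padding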
src/hybrid_clip/test_clip.ipynb CHANGED
@@ -26,77 +26,85 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 1,
-   "metadata": {},
-   "outputs": [],
+   "execution_count": 38,
    "source": [
+    "from transformers import AutoTokenizer\n",
+    "from configuration_hybrid_clip import HybridCLIPConfig\n",
     "from modeling_hybrid_clip import FlaxHybridCLIP\n",
-    "import jax\n"
-   ]
+    " \n",
+    "# config = HybridCLIPConfig.from_pretrained(\"../..\")\n",
+    "model = FlaxHybridCLIP.from_pretrained(\"flax-community/medclip-roco\")"
+   ],
+   "outputs": [],
+   "metadata": {}
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
-   "metadata": {},
-   "outputs": [
-    {
-     "output_type": "stream",
-     "name": "stderr",
-     "text": [
-      "INFO:absl:Starting the local TPU driver.\n",
-      "INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://\n",
-      "INFO:absl:Unable to initialize backend 'tpu': Invalid argument: TpuPlatform is not available.\n"
-     ]
-    }
-   ],
+   "execution_count": 39,
    "source": [
-    "model = FlaxHybridCLIP.from_text_vision_pretrained(\"bert-base-uncased\", \"openai/clip-vit-base-patch32\")"
-   ]
+    "tokenizer = AutoTokenizer.from_pretrained(\"allenai/scibert_scivocab_uncased\")"
+   ],
+   "outputs": [],
+   "metadata": {}
   },
   {
    "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
+   "execution_count": 31,
+   "source": [
+    "import torch\n",
+    "import numpy as np\n",
+    "from run_hybrid_clip import Transform\n",
+    "from torchvision.transforms.functional import InterpolationMode\n",
+    "\n",
+    "image_size = 224\n",
+    "transforms = Transform(image_size)\n"
+   ],
    "outputs": [],
+   "metadata": {}
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 40,
    "source": [
-    "model = FlaxHybridCLIP.from_text_vision_pretrained(\"bert-base-uncased\", \"openai/clip-vit-base-patch32\", text_from_pt=True, vision_from_pt=True)"
-   ]
+    "from torchvision.io.image import read_image, ImageReadMode\n",
+    "\n",
+    "# TODO create a batch of images\n",
+    "img = read_image('../../data/PMC4917066_amjcaserep-17-301-g001.jpg', mode=ImageReadMode.RGB)\n",
+    "tr_img = transforms(img)\n"
+   ],
+   "outputs": [],
+   "metadata": {}
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
-   "metadata": {},
+   "execution_count": 37,
+   "source": [
+    "max_seq_length = 128\n",
+    "pixel_values = torch.stack([example[0] for example in examples]).permute(0, 2, 3, 1).numpy()\n",
+    "# pixel_values = torch.stack([example[0] for example in examples]).numpy()\n",
+    "captions = [example[1] for example in examples]\n",
+    "inputs = tokenizer(captions, max_length=max_seq_length, padding=\"max_length\", return_tensors=\"np\",\n",
+    "    truncation=True)\n",
+    "batch = {\n",
+    "    \"pixel_values\": pixel_values,\n",
+    "    \"input_ids\": inputs[\"input_ids\"],\n",
+    "    \"attention_mask\": inputs[\"attention_mask\"],\n",
+    "}\n",
+    "logits = model(**batch, train=False)[0]"
+   ],
    "outputs": [
-    {
-     "output_type": "stream",
-     "name": "stderr",
-     "text": [
-      "INFO:absl:Starting the local TPU driver.\n",
-      "INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://\n",
-      "INFO:absl:Unable to initialize backend 'tpu': Invalid argument: TpuPlatform is not available.\n"
-     ]
-    },
     {
      "output_type": "execute_result",
      "data": {
       "text/plain": [
-       "[GpuDevice(id=0, process_index=0)]"
+       "torch.Size([3, 224, 224])"
       ]
      },
      "metadata": {},
-     "execution_count": 3
+     "execution_count": 37
     }
    ],
-   "source": [
-    "jax.devices()"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
+   "metadata": {}
   }
  ]
 }
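Note: the final evaluation cell in the notebook iterates over an `examples` list that is never built in the cells shown (the TODO above it flags the missing batch). A minimal sketch of how such a batch might be assembled, reusing `transforms` and `read_image` from the cells above; the caption text and the contents of the path/caption list are placeholders, not part of this commit:

from torchvision.io.image import read_image, ImageReadMode

# Hypothetical (image path, caption) pairs; only the first path appears in the notebook.
paths_and_captions = [
    ("../../data/PMC4917066_amjcaserep-17-301-g001.jpg",
     "placeholder caption describing the figure"),
]

# Build `examples` in the shape the last cell expects:
# example[0] -> transformed image tensor of shape (3, 224, 224), example[1] -> caption string.
examples = [
    (transforms(read_image(path, mode=ImageReadMode.RGB)), caption)
    for path, caption in paths_and_captions
]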