Navyabhat committed
Commit d43c6a1
1 Parent(s): 905e6c6

Upload 13 files

.gitignore ADDED
@@ -0,0 +1,160 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ #.idea/
Experiments/clip_expt.ipynb ADDED
@@ -0,0 +1,840 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 2,
6
+ "id": "9fe51ce7-4c87-4186-9fd3-0fb18ac43e56",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "from PIL import Image\n",
11
+ "import requests\n",
12
+ "from transformers import AutoProcessor, CLIPVisionModel"
13
+ ]
14
+ },
15
+ {
16
+ "cell_type": "code",
17
+ "execution_count": 3,
18
+ "id": "0f4c21dd-4258-461d-8511-5be089d068a8",
19
+ "metadata": {},
20
+ "outputs": [],
21
+ "source": [
22
+ "model = CLIPVisionModel.from_pretrained(\"openai/clip-vit-base-patch32\", device_map=\"cuda:0\")\n",
23
+ "processor = AutoProcessor.from_pretrained(\"openai/clip-vit-base-patch32\", device_map=\"cuda:0\")"
24
+ ]
25
+ },
26
+ {
27
+ "cell_type": "code",
28
+ "execution_count": 4,
29
+ "id": "98b9f906-ffaa-4be4-8671-4ecf65f12c49",
30
+ "metadata": {},
31
+ "outputs": [],
32
+ "source": [
33
+ "# url = \"http://images.cocodataset.org/val2017/000000039769.jpg\"\n",
34
+ "# image = Image.open(requests.get(url, stream=True).raw)\n",
35
+ "image = Image.open(\"002579.jpg\")"
36
+ ]
37
+ },
38
+ {
39
+ "cell_type": "code",
40
+ "execution_count": 17,
41
+ "id": "54b2e4ce-b77b-4314-87f6-ca2a1970fc79",
42
+ "metadata": {},
43
+ "outputs": [],
44
+ "source": [
45
+ "# image"
46
+ ]
47
+ },
48
+ {
49
+ "cell_type": "code",
50
+ "execution_count": 18,
51
+ "id": "cdd65c58-007f-450b-8deb-f8b4f372a823",
52
+ "metadata": {},
53
+ "outputs": [],
54
+ "source": [
55
+ "# image = None"
56
+ ]
57
+ },
58
+ {
59
+ "cell_type": "code",
60
+ "execution_count": 5,
61
+ "id": "e9066c2e-c78b-49d1-979b-10d0f4f09441",
62
+ "metadata": {},
63
+ "outputs": [],
64
+ "source": [
65
+ "inputs = processor(images=image, return_tensors=\"pt\", device_map=\"cuda:0\")"
66
+ ]
67
+ },
68
+ {
69
+ "cell_type": "code",
70
+ "execution_count": 20,
71
+ "id": "e98b211d-29d9-4662-be0b-e011e89b0101",
72
+ "metadata": {},
73
+ "outputs": [],
74
+ "source": [
75
+ "# inputs"
76
+ ]
77
+ },
78
+ {
79
+ "cell_type": "code",
80
+ "execution_count": 6,
81
+ "id": "b030bd3d-4282-4074-98fe-97e658bd0f50",
82
+ "metadata": {},
83
+ "outputs": [
84
+ {
85
+ "data": {
86
+ "text/plain": [
87
+ "torch.Size([1, 3, 224, 224])"
88
+ ]
89
+ },
90
+ "execution_count": 6,
91
+ "metadata": {},
92
+ "output_type": "execute_result"
93
+ }
94
+ ],
95
+ "source": [
96
+ "inputs[\"pixel_values\"].shape"
97
+ ]
98
+ },
99
+ {
100
+ "cell_type": "code",
101
+ "execution_count": 22,
102
+ "id": "0ce68f11-1c88-4dd7-8b17-0d1de5811fe6",
103
+ "metadata": {},
104
+ "outputs": [],
105
+ "source": [
106
+ "outputs = model(inputs[\"pixel_values\"].to(\"cuda:0\"))\n",
107
+ "last_hidden_state = outputs.last_hidden_state\n",
108
+ "pooled_output = outputs.pooler_output # pooled CLS states"
109
+ ]
110
+ },
111
+ {
112
+ "cell_type": "code",
113
+ "execution_count": 23,
114
+ "id": "30cb0918-a30e-4246-b540-6b8e0d876807",
115
+ "metadata": {},
116
+ "outputs": [
117
+ {
118
+ "data": {
119
+ "text/plain": [
120
+ "torch.Size([1, 768])"
121
+ ]
122
+ },
123
+ "execution_count": 23,
124
+ "metadata": {},
125
+ "output_type": "execute_result"
126
+ }
127
+ ],
128
+ "source": [
129
+ "pooled_output.shape"
130
+ ]
131
+ },
132
+ {
133
+ "cell_type": "code",
134
+ "execution_count": 24,
135
+ "id": "6399543a-f23f-426d-8289-3bb52d293ece",
136
+ "metadata": {},
137
+ "outputs": [
138
+ {
139
+ "data": {
140
+ "text/plain": [
141
+ "torch.Size([1, 50, 768])"
142
+ ]
143
+ },
144
+ "execution_count": 24,
145
+ "metadata": {},
146
+ "output_type": "execute_result"
147
+ }
148
+ ],
149
+ "source": [
150
+ "last_hidden_state.shape"
151
+ ]
152
+ },
153
+ {
154
+ "cell_type": "code",
155
+ "execution_count": 25,
156
+ "id": "19a70443-5942-4937-b3ea-6a52d76e2b08",
157
+ "metadata": {},
158
+ "outputs": [
159
+ {
160
+ "data": {
161
+ "text/plain": [
162
+ "torch.Size([1, 768])"
163
+ ]
164
+ },
165
+ "execution_count": 25,
166
+ "metadata": {},
167
+ "output_type": "execute_result"
168
+ }
169
+ ],
170
+ "source": [
171
+ "outputs[1].shape"
172
+ ]
173
+ },
174
+ {
175
+ "cell_type": "code",
176
+ "execution_count": 8,
177
+ "id": "fa13903f-a94a-4839-ae5a-8df4f55c68b6",
178
+ "metadata": {},
179
+ "outputs": [],
180
+ "source": [
181
+ "import torch\n",
182
+ "from torch import nn\n",
183
+ "from transformers import CLIPVisionConfig,CLIPPreTrainedModel"
184
+ ]
185
+ },
186
+ {
187
+ "cell_type": "code",
188
+ "execution_count": 9,
189
+ "id": "b2bd9198-42f0-40c3-80e1-d167c0b038fb",
190
+ "metadata": {},
191
+ "outputs": [
192
+ {
193
+ "ename": "NameError",
194
+ "evalue": "name 'Optional' is not defined",
195
+ "output_type": "error",
196
+ "traceback": [
197
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
198
+ "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
199
+ "Cell \u001b[0;32mIn[9], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mCLIPVisionModelWithProjection\u001b[39;00m(CLIPPreTrainedModel):\n\u001b[1;32m 2\u001b[0m config_class \u001b[38;5;241m=\u001b[39m CLIPVisionConfig\n\u001b[1;32m 3\u001b[0m main_input_name \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpixel_values\u001b[39m\u001b[38;5;124m\"\u001b[39m\n",
200
+ "Cell \u001b[0;32mIn[9], line 20\u001b[0m, in \u001b[0;36mCLIPVisionModelWithProjection\u001b[0;34m()\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mget_input_embeddings\u001b[39m(\u001b[38;5;28mself\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m nn\u001b[38;5;241m.\u001b[39mModule:\n\u001b[1;32m 16\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mvision_model\u001b[38;5;241m.\u001b[39membeddings\u001b[38;5;241m.\u001b[39mpatch_embedding\n\u001b[1;32m 18\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\n\u001b[1;32m 19\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[0;32m---> 20\u001b[0m pixel_values: \u001b[43mOptional\u001b[49m[torch\u001b[38;5;241m.\u001b[39mFloatTensor] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 21\u001b[0m output_attentions: Optional[\u001b[38;5;28mbool\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 22\u001b[0m output_hidden_states: Optional[\u001b[38;5;28mbool\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 23\u001b[0m return_dict: Optional[\u001b[38;5;28mbool\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 24\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Union[Tuple, CLIPVisionModelOutput]:\n\u001b[1;32m 25\u001b[0m return_dict \u001b[38;5;241m=\u001b[39m return_dict \u001b[38;5;28;01mif\u001b[39;00m return_dict \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39muse_return_dict\n\u001b[1;32m 27\u001b[0m vision_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mvision_model(\n\u001b[1;32m 28\u001b[0m pixel_values\u001b[38;5;241m=\u001b[39mpixel_values,\n\u001b[1;32m 29\u001b[0m output_attentions\u001b[38;5;241m=\u001b[39moutput_attentions,\n\u001b[1;32m 30\u001b[0m output_hidden_states\u001b[38;5;241m=\u001b[39moutput_hidden_states,\n\u001b[1;32m 31\u001b[0m return_dict\u001b[38;5;241m=\u001b[39mreturn_dict,\n\u001b[1;32m 32\u001b[0m )\n",
201
+ "\u001b[0;31mNameError\u001b[0m: name 'Optional' is not defined"
202
+ ]
203
+ }
204
+ ],
205
+ "source": [
206
+ "class CLIPVisionModelWithProjection(CLIPPreTrainedModel):\n",
207
+ " config_class = CLIPVisionConfig\n",
208
+ " main_input_name = \"pixel_values\"\n",
209
+ "\n",
210
+ " def __init__(self, config: CLIPVisionConfig):\n",
211
+ " super().__init__(config)\n",
212
+ "\n",
213
+ " self.vision_model = CLIPVisionTransformer(config)\n",
214
+ "\n",
215
+ " self.visual_projection = nn.Linear(config.hidden_size, config.projection_dim, bias=False)\n",
216
+ "\n",
217
+ " # Initialize weights and apply final processing\n",
218
+ " self.post_init()\n",
219
+ "\n",
220
+ " def get_input_embeddings(self) -> nn.Module:\n",
221
+ " return self.vision_model.embeddings.patch_embedding\n",
222
+ "\n",
223
+ " def forward(\n",
224
+ " self,\n",
225
+ " pixel_values: Optional[torch.FloatTensor] = None,\n",
226
+ " output_attentions: Optional[bool] = None,\n",
227
+ " output_hidden_states: Optional[bool] = None,\n",
228
+ " return_dict: Optional[bool] = None,\n",
229
+ " ) -> Union[Tuple, CLIPVisionModelOutput]:\n",
230
+ " return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n",
231
+ "\n",
232
+ " vision_outputs = self.vision_model(\n",
233
+ " pixel_values=pixel_values,\n",
234
+ " output_attentions=output_attentions,\n",
235
+ " output_hidden_states=output_hidden_states,\n",
236
+ " return_dict=return_dict,\n",
237
+ " )\n",
238
+ "\n",
239
+ " pooled_output = vision_outputs[1] # pooled_output\n",
240
+ "\n",
241
+ " image_embeds = self.visual_projection(pooled_output)\n",
242
+ "\n",
243
+ " if not return_dict:\n",
244
+ " outputs = (image_embeds, vision_outputs[0]) + vision_outputs[2:]\n",
245
+ " return tuple(output for output in outputs if output is not None)\n",
246
+ "\n",
247
+ " return CLIPVisionModelOutput(\n",
248
+ " image_embeds=image_embeds,\n",
249
+ " last_hidden_state=vision_outputs.last_hidden_state,\n",
250
+ " hidden_states=vision_outputs.hidden_states,\n",
251
+ " attentions=vision_outputs.attentions,\n",
252
+ " )"
253
+ ]
254
+ },
255
+ {
256
+ "cell_type": "code",
257
+ "execution_count": 27,
258
+ "id": "68a9ee4a-d977-4725-842d-e64e0dd2f61d",
259
+ "metadata": {
260
+ "collapsed": true,
261
+ "jupyter": {
262
+ "outputs_hidden": true
263
+ }
264
+ },
265
+ "outputs": [
266
+ {
267
+ "name": "stderr",
268
+ "output_type": "stream",
269
+ "text": [
270
+ "loading configuration file config.json from cache at /home/ravi.naik/.cache/huggingface/hub/models--openai--clip-vit-base-patch32/snapshots/e6a30b603a447e251fdaca1c3056b2a16cdfebeb/config.json\n",
271
+ "`text_config` is `None`. Initializing the `CLIPTextConfig` with default values.\n",
272
+ "`vision_config` is `None`. initializing the `CLIPVisionConfig` with default values.\n",
273
+ "Model config CLIPConfig {\n",
274
+ " \"_name_or_path\": \"openai/clip-vit-base-patch32\",\n",
275
+ " \"architectures\": [\n",
276
+ " \"CLIPModel\"\n",
277
+ " ],\n",
278
+ " \"initializer_factor\": 1.0,\n",
279
+ " \"logit_scale_init_value\": 2.6592,\n",
280
+ " \"model_type\": \"clip\",\n",
281
+ " \"projection_dim\": 512,\n",
282
+ " \"text_config\": {\n",
283
+ " \"bos_token_id\": 0,\n",
284
+ " \"dropout\": 0.0,\n",
285
+ " \"eos_token_id\": 2,\n",
286
+ " \"model_type\": \"clip_text_model\"\n",
287
+ " },\n",
288
+ " \"transformers_version\": \"4.36.2\",\n",
289
+ " \"vision_config\": {\n",
290
+ " \"dropout\": 0.0,\n",
291
+ " \"model_type\": \"clip_vision_model\"\n",
292
+ " }\n",
293
+ "}\n",
294
+ "\n",
295
+ "loading weights file pytorch_model.bin from cache at /home/ravi.naik/.cache/huggingface/hub/models--openai--clip-vit-base-patch32/snapshots/e6a30b603a447e251fdaca1c3056b2a16cdfebeb/pytorch_model.bin\n",
296
+ "All model checkpoint weights were used when initializing CLIPModel.\n",
297
+ "\n",
298
+ "All the weights of CLIPModel were initialized from the model checkpoint at openai/clip-vit-base-patch32.\n",
299
+ "If your task is similar to the task the model of the checkpoint was trained on, you can already use CLIPModel for predictions without further training.\n",
300
+ "loading configuration file preprocessor_config.json from cache at /home/ravi.naik/.cache/huggingface/hub/models--openai--clip-vit-base-patch32/snapshots/e6a30b603a447e251fdaca1c3056b2a16cdfebeb/preprocessor_config.json\n",
301
+ "loading configuration file preprocessor_config.json from cache at /home/ravi.naik/.cache/huggingface/hub/models--openai--clip-vit-base-patch32/snapshots/e6a30b603a447e251fdaca1c3056b2a16cdfebeb/preprocessor_config.json\n",
302
+ "loading configuration file config.json from cache at /home/ravi.naik/.cache/huggingface/hub/models--openai--clip-vit-base-patch32/snapshots/e6a30b603a447e251fdaca1c3056b2a16cdfebeb/config.json\n",
303
+ "`text_config` is `None`. Initializing the `CLIPTextConfig` with default values.\n",
304
+ "`vision_config` is `None`. initializing the `CLIPVisionConfig` with default values.\n",
305
+ "Model config CLIPConfig {\n",
306
+ " \"_name_or_path\": \"openai/clip-vit-base-patch32\",\n",
307
+ " \"architectures\": [\n",
308
+ " \"CLIPModel\"\n",
309
+ " ],\n",
310
+ " \"initializer_factor\": 1.0,\n",
311
+ " \"logit_scale_init_value\": 2.6592,\n",
312
+ " \"model_type\": \"clip\",\n",
313
+ " \"projection_dim\": 512,\n",
314
+ " \"text_config\": {\n",
315
+ " \"bos_token_id\": 0,\n",
316
+ " \"dropout\": 0.0,\n",
317
+ " \"eos_token_id\": 2,\n",
318
+ " \"model_type\": \"clip_text_model\"\n",
319
+ " },\n",
320
+ " \"transformers_version\": \"4.36.2\",\n",
321
+ " \"vision_config\": {\n",
322
+ " \"dropout\": 0.0,\n",
323
+ " \"model_type\": \"clip_vision_model\"\n",
324
+ " }\n",
325
+ "}\n",
326
+ "\n",
327
+ "loading configuration file preprocessor_config.json from cache at /home/ravi.naik/.cache/huggingface/hub/models--openai--clip-vit-base-patch32/snapshots/e6a30b603a447e251fdaca1c3056b2a16cdfebeb/preprocessor_config.json\n",
328
+ "size should be a dictionary on of the following set of keys: ({'width', 'height'}, {'shortest_edge'}, {'longest_edge', 'shortest_edge'}, {'longest_edge'}), got 224. Converted to {'shortest_edge': 224}.\n",
329
+ "crop_size should be a dictionary on of the following set of keys: ({'width', 'height'}, {'shortest_edge'}, {'longest_edge', 'shortest_edge'}, {'longest_edge'}), got 224. Converted to {'height': 224, 'width': 224}.\n",
330
+ "Image processor CLIPImageProcessor {\n",
331
+ " \"crop_size\": {\n",
332
+ " \"height\": 224,\n",
333
+ " \"width\": 224\n",
334
+ " },\n",
335
+ " \"do_center_crop\": true,\n",
336
+ " \"do_convert_rgb\": true,\n",
337
+ " \"do_normalize\": true,\n",
338
+ " \"do_rescale\": true,\n",
339
+ " \"do_resize\": true,\n",
340
+ " \"feature_extractor_type\": \"CLIPFeatureExtractor\",\n",
341
+ " \"image_mean\": [\n",
342
+ " 0.48145466,\n",
343
+ " 0.4578275,\n",
344
+ " 0.40821073\n",
345
+ " ],\n",
346
+ " \"image_processor_type\": \"CLIPImageProcessor\",\n",
347
+ " \"image_std\": [\n",
348
+ " 0.26862954,\n",
349
+ " 0.26130258,\n",
350
+ " 0.27577711\n",
351
+ " ],\n",
352
+ " \"resample\": 3,\n",
353
+ " \"rescale_factor\": 0.00392156862745098,\n",
354
+ " \"size\": {\n",
355
+ " \"shortest_edge\": 224\n",
356
+ " }\n",
357
+ "}\n",
358
+ "\n",
359
+ "loading file vocab.json from cache at /home/ravi.naik/.cache/huggingface/hub/models--openai--clip-vit-base-patch32/snapshots/e6a30b603a447e251fdaca1c3056b2a16cdfebeb/vocab.json\n",
360
+ "loading file merges.txt from cache at /home/ravi.naik/.cache/huggingface/hub/models--openai--clip-vit-base-patch32/snapshots/e6a30b603a447e251fdaca1c3056b2a16cdfebeb/merges.txt\n",
361
+ "loading file tokenizer.json from cache at /home/ravi.naik/.cache/huggingface/hub/models--openai--clip-vit-base-patch32/snapshots/e6a30b603a447e251fdaca1c3056b2a16cdfebeb/tokenizer.json\n",
362
+ "loading file added_tokens.json from cache at None\n",
363
+ "loading file special_tokens_map.json from cache at /home/ravi.naik/.cache/huggingface/hub/models--openai--clip-vit-base-patch32/snapshots/e6a30b603a447e251fdaca1c3056b2a16cdfebeb/special_tokens_map.json\n",
364
+ "loading file tokenizer_config.json from cache at /home/ravi.naik/.cache/huggingface/hub/models--openai--clip-vit-base-patch32/snapshots/e6a30b603a447e251fdaca1c3056b2a16cdfebeb/tokenizer_config.json\n",
365
+ "loading configuration file config.json from cache at /home/ravi.naik/.cache/huggingface/hub/models--openai--clip-vit-base-patch32/snapshots/e6a30b603a447e251fdaca1c3056b2a16cdfebeb/config.json\n",
366
+ "`text_config` is `None`. Initializing the `CLIPTextConfig` with default values.\n",
367
+ "`vision_config` is `None`. initializing the `CLIPVisionConfig` with default values.\n",
368
+ "Model config CLIPConfig {\n",
369
+ " \"_name_or_path\": \"openai/clip-vit-base-patch32\",\n",
370
+ " \"architectures\": [\n",
371
+ " \"CLIPModel\"\n",
372
+ " ],\n",
373
+ " \"initializer_factor\": 1.0,\n",
374
+ " \"logit_scale_init_value\": 2.6592,\n",
375
+ " \"model_type\": \"clip\",\n",
376
+ " \"projection_dim\": 512,\n",
377
+ " \"text_config\": {\n",
378
+ " \"bos_token_id\": 0,\n",
379
+ " \"dropout\": 0.0,\n",
380
+ " \"eos_token_id\": 2,\n",
381
+ " \"model_type\": \"clip_text_model\"\n",
382
+ " },\n",
383
+ " \"transformers_version\": \"4.36.2\",\n",
384
+ " \"vision_config\": {\n",
385
+ " \"dropout\": 0.0,\n",
386
+ " \"model_type\": \"clip_vision_model\"\n",
387
+ " }\n",
388
+ "}\n",
389
+ "\n"
390
+ ]
391
+ }
392
+ ],
393
+ "source": [
394
+ "from PIL import Image\n",
395
+ "import requests\n",
396
+ "from transformers import AutoProcessor, CLIPModel\n",
397
+ "\n",
398
+ "model = CLIPModel.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
399
+ "processor = AutoProcessor.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
400
+ "\n",
401
+ "url = \"http://images.cocodataset.org/val2017/000000039769.jpg\"\n",
402
+ "image = Image.open(requests.get(url, stream=True).raw)\n",
403
+ "\n",
404
+ "inputs = processor(images=image, return_tensors=\"pt\")\n",
405
+ "\n",
406
+ "image_features = model.get_image_features(**inputs)"
407
+ ]
408
+ },
409
+ {
410
+ "cell_type": "code",
411
+ "execution_count": 29,
412
+ "id": "9ff63766-b706-452b-b735-bf9000fb9c20",
413
+ "metadata": {},
414
+ "outputs": [
415
+ {
416
+ "data": {
417
+ "text/plain": [
418
+ "torch.Size([1, 512])"
419
+ ]
420
+ },
421
+ "execution_count": 29,
422
+ "metadata": {},
423
+ "output_type": "execute_result"
424
+ }
425
+ ],
426
+ "source": [
427
+ "image_features.shape"
428
+ ]
429
+ },
430
+ {
431
+ "cell_type": "code",
432
+ "execution_count": 30,
433
+ "id": "82566e7b-3c91-421a-94c5-f1e2b3e91c8c",
434
+ "metadata": {
435
+ "collapsed": true,
436
+ "jupyter": {
437
+ "outputs_hidden": true
438
+ }
439
+ },
440
+ "outputs": [
441
+ {
442
+ "name": "stderr",
443
+ "output_type": "stream",
444
+ "text": [
445
+ "loading configuration file config.json from cache at /home/ravi.naik/.cache/huggingface/hub/models--openai--clip-vit-base-patch32/snapshots/e6a30b603a447e251fdaca1c3056b2a16cdfebeb/config.json\n",
446
+ "Model config CLIPVisionConfig {\n",
447
+ " \"attention_dropout\": 0.0,\n",
448
+ " \"dropout\": 0.0,\n",
449
+ " \"hidden_act\": \"quick_gelu\",\n",
450
+ " \"hidden_size\": 768,\n",
451
+ " \"image_size\": 224,\n",
452
+ " \"initializer_factor\": 1.0,\n",
453
+ " \"initializer_range\": 0.02,\n",
454
+ " \"intermediate_size\": 3072,\n",
455
+ " \"layer_norm_eps\": 1e-05,\n",
456
+ " \"model_type\": \"clip_vision_model\",\n",
457
+ " \"num_attention_heads\": 12,\n",
458
+ " \"num_channels\": 3,\n",
459
+ " \"num_hidden_layers\": 12,\n",
460
+ " \"patch_size\": 32,\n",
461
+ " \"projection_dim\": 512,\n",
462
+ " \"transformers_version\": \"4.36.2\"\n",
463
+ "}\n",
464
+ "\n",
465
+ "loading weights file pytorch_model.bin from cache at /home/ravi.naik/.cache/huggingface/hub/models--openai--clip-vit-base-patch32/snapshots/e6a30b603a447e251fdaca1c3056b2a16cdfebeb/pytorch_model.bin\n",
466
+ "Some weights of the model checkpoint at openai/clip-vit-base-patch32 were not used when initializing CLIPVisionModel: ['text_model.encoder.layers.8.mlp.fc2.weight', 'text_model.encoder.layers.4.self_attn.v_proj.weight', 'text_model.encoder.layers.2.mlp.fc2.bias', 'text_model.encoder.layers.2.self_attn.q_proj.bias', 'text_model.encoder.layers.4.self_attn.v_proj.bias', 'text_model.encoder.layers.6.self_attn.k_proj.bias', 'text_model.encoder.layers.9.self_attn.k_proj.weight', 'text_model.encoder.layers.11.self_attn.q_proj.bias', 'text_model.encoder.layers.3.self_attn.out_proj.bias', 'text_model.encoder.layers.11.self_attn.k_proj.bias', 'text_model.encoder.layers.7.self_attn.k_proj.bias', 'text_model.encoder.layers.1.self_attn.q_proj.weight', 'text_model.encoder.layers.11.layer_norm1.bias', 'text_model.encoder.layers.11.mlp.fc2.bias', 'text_model.encoder.layers.10.layer_norm1.bias', 'text_model.encoder.layers.7.self_attn.q_proj.bias', 'text_model.encoder.layers.11.self_attn.k_proj.weight', 'text_model.encoder.layers.5.self_attn.q_proj.bias', 'text_model.encoder.layers.0.self_attn.v_proj.bias', 'logit_scale', 'text_model.encoder.layers.8.mlp.fc1.bias', 'text_model.encoder.layers.6.layer_norm1.bias', 'text_model.encoder.layers.5.self_attn.out_proj.weight', 'text_model.encoder.layers.7.self_attn.out_proj.bias', 'text_model.embeddings.token_embedding.weight', 'text_model.encoder.layers.8.layer_norm2.bias', 'text_model.encoder.layers.9.self_attn.v_proj.weight', 'text_model.encoder.layers.1.mlp.fc1.weight', 'text_model.encoder.layers.0.layer_norm1.weight', 'text_model.encoder.layers.6.self_attn.k_proj.weight', 'text_model.encoder.layers.3.self_attn.q_proj.weight', 'text_model.encoder.layers.2.layer_norm1.weight', 'text_model.encoder.layers.0.self_attn.v_proj.weight', 'text_model.encoder.layers.7.self_attn.q_proj.weight', 'text_model.encoder.layers.1.layer_norm2.weight', 'text_model.encoder.layers.2.self_attn.out_proj.weight', 'text_model.encoder.layers.3.self_attn.k_proj.weight', 'text_model.encoder.layers.7.mlp.fc2.bias', 'text_model.encoder.layers.10.self_attn.out_proj.weight', 'text_model.encoder.layers.2.self_attn.q_proj.weight', 'text_model.encoder.layers.1.self_attn.k_proj.weight', 'text_model.encoder.layers.4.layer_norm1.weight', 'text_model.encoder.layers.0.self_attn.q_proj.weight', 'text_model.encoder.layers.0.self_attn.out_proj.bias', 'text_model.encoder.layers.4.self_attn.out_proj.bias', 'text_model.encoder.layers.5.self_attn.k_proj.weight', 'visual_projection.weight', 'text_model.encoder.layers.6.layer_norm2.bias', 'text_model.encoder.layers.6.layer_norm1.weight', 'text_model.encoder.layers.4.self_attn.out_proj.weight', 'text_model.encoder.layers.10.mlp.fc2.bias', 'text_model.encoder.layers.10.mlp.fc1.weight', 'text_model.encoder.layers.6.self_attn.out_proj.weight', 'text_model.encoder.layers.9.layer_norm1.weight', 'text_model.encoder.layers.11.layer_norm2.weight', 'text_model.encoder.layers.6.self_attn.q_proj.bias', 'text_model.encoder.layers.5.mlp.fc1.weight', 'text_model.encoder.layers.2.mlp.fc1.weight', 'text_model.encoder.layers.11.self_attn.out_proj.weight', 'text_model.encoder.layers.0.self_attn.out_proj.weight', 'text_model.encoder.layers.11.mlp.fc2.weight', 'text_model.encoder.layers.7.layer_norm2.weight', 'text_model.encoder.layers.10.self_attn.v_proj.bias', 'text_model.encoder.layers.9.mlp.fc1.bias', 'text_model.encoder.layers.8.self_attn.v_proj.weight', 'text_model.encoder.layers.3.layer_norm1.bias', 'text_model.encoder.layers.6.self_attn.v_proj.bias', 
'text_model.encoder.layers.1.self_attn.v_proj.bias', 'text_model.encoder.layers.9.self_attn.q_proj.weight', 'text_model.encoder.layers.4.self_attn.k_proj.weight', 'text_model.encoder.layers.7.layer_norm1.weight', 'text_model.encoder.layers.10.self_attn.k_proj.weight', 'text_model.encoder.layers.7.self_attn.v_proj.bias', 'text_model.encoder.layers.7.mlp.fc1.bias', 'text_model.encoder.layers.11.mlp.fc1.weight', 'text_model.encoder.layers.2.mlp.fc1.bias', 'text_model.encoder.layers.3.mlp.fc2.bias', 'text_model.encoder.layers.8.self_attn.q_proj.weight', 'text_model.encoder.layers.0.mlp.fc1.weight', 'text_model.encoder.layers.11.self_attn.out_proj.bias', 'text_model.encoder.layers.1.self_attn.v_proj.weight', 'text_model.encoder.layers.0.self_attn.k_proj.weight', 'text_model.encoder.layers.9.layer_norm1.bias', 'text_model.final_layer_norm.weight', 'text_model.encoder.layers.3.layer_norm1.weight', 'text_model.encoder.layers.4.mlp.fc1.bias', 'text_model.encoder.layers.1.layer_norm1.weight', 'text_model.encoder.layers.10.layer_norm2.bias', 'text_model.encoder.layers.9.self_attn.v_proj.bias', 'text_model.encoder.layers.10.self_attn.k_proj.bias', 'text_model.encoder.layers.8.mlp.fc2.bias', 'text_model.encoder.layers.5.mlp.fc2.bias', 'text_model.encoder.layers.6.self_attn.q_proj.weight', 'text_model.encoder.layers.5.self_attn.out_proj.bias', 'text_model.encoder.layers.9.mlp.fc2.bias', 'text_model.encoder.layers.5.layer_norm2.weight', 'text_model.encoder.layers.2.mlp.fc2.weight', 'text_model.encoder.layers.3.self_attn.out_proj.weight', 'text_model.encoder.layers.6.mlp.fc2.weight', 'text_model.encoder.layers.1.self_attn.out_proj.weight', 'text_model.encoder.layers.1.mlp.fc2.bias', 'text_model.encoder.layers.7.mlp.fc2.weight', 'text_model.encoder.layers.10.self_attn.v_proj.weight', 'text_model.encoder.layers.11.self_attn.v_proj.bias', 'text_model.encoder.layers.4.layer_norm1.bias', 'text_model.encoder.layers.4.layer_norm2.bias', 'text_model.encoder.layers.8.self_attn.q_proj.bias', 'text_model.embeddings.position_ids', 'text_model.encoder.layers.10.layer_norm2.weight', 'text_model.encoder.layers.1.self_attn.out_proj.bias', 'text_model.encoder.layers.2.layer_norm2.weight', 'text_model.encoder.layers.10.self_attn.q_proj.weight', 'text_model.encoder.layers.4.mlp.fc1.weight', 'text_model.encoder.layers.8.layer_norm1.bias', 'text_model.encoder.layers.2.self_attn.k_proj.weight', 'text_model.encoder.layers.5.mlp.fc1.bias', 'text_model.encoder.layers.9.self_attn.out_proj.bias', 'text_model.encoder.layers.7.self_attn.v_proj.weight', 'text_model.encoder.layers.2.self_attn.k_proj.bias', 'text_model.encoder.layers.5.self_attn.k_proj.bias', 'text_model.encoder.layers.8.self_attn.out_proj.bias', 'text_model.encoder.layers.7.self_attn.k_proj.weight', 'text_model.encoder.layers.6.mlp.fc1.weight', 'text_model.encoder.layers.6.mlp.fc1.bias', 'text_model.encoder.layers.3.self_attn.v_proj.weight', 'text_model.encoder.layers.3.self_attn.q_proj.bias', 'text_model.encoder.layers.9.self_attn.out_proj.weight', 'text_model.encoder.layers.3.mlp.fc1.bias', 'text_model.encoder.layers.0.self_attn.q_proj.bias', 'text_model.encoder.layers.1.layer_norm2.bias', 'text_model.encoder.layers.8.layer_norm2.weight', 'text_model.encoder.layers.5.self_attn.q_proj.weight', 'text_model.encoder.layers.4.layer_norm2.weight', 'text_model.encoder.layers.4.mlp.fc2.bias', 'text_model.encoder.layers.9.mlp.fc2.weight', 'text_model.encoder.layers.8.self_attn.k_proj.weight', 'text_model.encoder.layers.10.layer_norm1.weight', 
'text_model.encoder.layers.0.self_attn.k_proj.bias', 'text_model.encoder.layers.8.self_attn.k_proj.bias', 'text_model.encoder.layers.9.layer_norm2.weight', 'text_model.encoder.layers.4.self_attn.k_proj.bias', 'text_model.encoder.layers.6.layer_norm2.weight', 'text_model.encoder.layers.0.layer_norm2.weight', 'text_model.encoder.layers.5.self_attn.v_proj.bias', 'text_model.encoder.layers.3.layer_norm2.bias', 'text_model.encoder.layers.8.mlp.fc1.weight', 'text_model.encoder.layers.4.self_attn.q_proj.bias', 'text_model.encoder.layers.8.layer_norm1.weight', 'text_model.encoder.layers.2.self_attn.v_proj.weight', 'text_model.encoder.layers.3.self_attn.v_proj.bias', 'text_model.encoder.layers.11.mlp.fc1.bias', 'text_model.encoder.layers.6.mlp.fc2.bias', 'text_model.encoder.layers.1.mlp.fc1.bias', 'text_model.encoder.layers.2.self_attn.v_proj.bias', 'text_model.encoder.layers.5.mlp.fc2.weight', 'text_model.encoder.layers.8.self_attn.v_proj.bias', 'text_model.encoder.layers.10.self_attn.out_proj.bias', 'text_model.encoder.layers.5.layer_norm1.bias', 'text_model.encoder.layers.5.self_attn.v_proj.weight', 'text_model.encoder.layers.10.self_attn.q_proj.bias', 'text_model.encoder.layers.2.layer_norm2.bias', 'text_model.encoder.layers.7.layer_norm1.bias', 'text_model.encoder.layers.4.mlp.fc2.weight', 'text_model.encoder.layers.10.mlp.fc2.weight', 'text_model.encoder.layers.3.mlp.fc1.weight', 'text_model.encoder.layers.5.layer_norm2.bias', 'text_model.encoder.layers.9.self_attn.q_proj.bias', 'text_model.encoder.layers.1.self_attn.k_proj.bias', 'text_model.encoder.layers.7.self_attn.out_proj.weight', 'text_model.encoder.layers.0.mlp.fc2.weight', 'text_model.encoder.layers.11.self_attn.v_proj.weight', 'text_model.encoder.layers.1.layer_norm1.bias', 'text_model.encoder.layers.1.mlp.fc2.weight', 'text_model.encoder.layers.9.layer_norm2.bias', 'text_model.encoder.layers.9.self_attn.k_proj.bias', 'text_model.encoder.layers.11.layer_norm1.weight', 'text_model.encoder.layers.8.self_attn.out_proj.weight', 'text_model.encoder.layers.0.layer_norm1.bias', 'text_model.encoder.layers.7.mlp.fc1.weight', 'text_model.encoder.layers.0.mlp.fc1.bias', 'text_model.encoder.layers.0.layer_norm2.bias', 'text_model.encoder.layers.3.self_attn.k_proj.bias', 'text_model.encoder.layers.5.layer_norm1.weight', 'text_model.encoder.layers.3.layer_norm2.weight', 'text_model.encoder.layers.1.self_attn.q_proj.bias', 'text_model.encoder.layers.2.self_attn.out_proj.bias', 'text_model.encoder.layers.3.mlp.fc2.weight', 'text_model.encoder.layers.11.self_attn.q_proj.weight', 'text_model.final_layer_norm.bias', 'text_model.encoder.layers.6.self_attn.v_proj.weight', 'text_model.encoder.layers.0.mlp.fc2.bias', 'text_model.encoder.layers.7.layer_norm2.bias', 'text_model.encoder.layers.10.mlp.fc1.bias', 'text_model.embeddings.position_embedding.weight', 'text_model.encoder.layers.6.self_attn.out_proj.bias', 'text_model.encoder.layers.2.layer_norm1.bias', 'text_model.encoder.layers.9.mlp.fc1.weight', 'text_projection.weight', 'text_model.encoder.layers.11.layer_norm2.bias', 'text_model.encoder.layers.4.self_attn.q_proj.weight']\n",
467
+ "- This IS expected if you are initializing CLIPVisionModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
468
+ "- This IS NOT expected if you are initializing CLIPVisionModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
469
+ "All the weights of CLIPVisionModel were initialized from the model checkpoint at openai/clip-vit-base-patch32.\n",
470
+ "If your task is similar to the task the model of the checkpoint was trained on, you can already use CLIPVisionModel for predictions without further training.\n",
471
+ "loading configuration file preprocessor_config.json from cache at /home/ravi.naik/.cache/huggingface/hub/models--openai--clip-vit-base-patch32/snapshots/e6a30b603a447e251fdaca1c3056b2a16cdfebeb/preprocessor_config.json\n",
472
+ "loading configuration file preprocessor_config.json from cache at /home/ravi.naik/.cache/huggingface/hub/models--openai--clip-vit-base-patch32/snapshots/e6a30b603a447e251fdaca1c3056b2a16cdfebeb/preprocessor_config.json\n",
473
+ "loading configuration file config.json from cache at /home/ravi.naik/.cache/huggingface/hub/models--openai--clip-vit-base-patch32/snapshots/e6a30b603a447e251fdaca1c3056b2a16cdfebeb/config.json\n",
474
+ "`text_config` is `None`. Initializing the `CLIPTextConfig` with default values.\n",
475
+ "`vision_config` is `None`. initializing the `CLIPVisionConfig` with default values.\n",
476
+ "Model config CLIPConfig {\n",
477
+ " \"_name_or_path\": \"openai/clip-vit-base-patch32\",\n",
478
+ " \"architectures\": [\n",
479
+ " \"CLIPModel\"\n",
480
+ " ],\n",
481
+ " \"initializer_factor\": 1.0,\n",
482
+ " \"logit_scale_init_value\": 2.6592,\n",
483
+ " \"model_type\": \"clip\",\n",
484
+ " \"projection_dim\": 512,\n",
485
+ " \"text_config\": {\n",
486
+ " \"bos_token_id\": 0,\n",
487
+ " \"dropout\": 0.0,\n",
488
+ " \"eos_token_id\": 2,\n",
489
+ " \"model_type\": \"clip_text_model\"\n",
490
+ " },\n",
491
+ " \"transformers_version\": \"4.36.2\",\n",
492
+ " \"vision_config\": {\n",
493
+ " \"dropout\": 0.0,\n",
494
+ " \"model_type\": \"clip_vision_model\"\n",
495
+ " }\n",
496
+ "}\n",
497
+ "\n",
498
+ "loading configuration file preprocessor_config.json from cache at /home/ravi.naik/.cache/huggingface/hub/models--openai--clip-vit-base-patch32/snapshots/e6a30b603a447e251fdaca1c3056b2a16cdfebeb/preprocessor_config.json\n",
499
+ "size should be a dictionary on of the following set of keys: ({'width', 'height'}, {'shortest_edge'}, {'longest_edge', 'shortest_edge'}, {'longest_edge'}), got 224. Converted to {'shortest_edge': 224}.\n",
500
+ "crop_size should be a dictionary on of the following set of keys: ({'width', 'height'}, {'shortest_edge'}, {'longest_edge', 'shortest_edge'}, {'longest_edge'}), got 224. Converted to {'height': 224, 'width': 224}.\n",
501
+ "Image processor CLIPImageProcessor {\n",
502
+ " \"crop_size\": {\n",
503
+ " \"height\": 224,\n",
504
+ " \"width\": 224\n",
505
+ " },\n",
506
+ " \"do_center_crop\": true,\n",
507
+ " \"do_convert_rgb\": true,\n",
508
+ " \"do_normalize\": true,\n",
509
+ " \"do_rescale\": true,\n",
510
+ " \"do_resize\": true,\n",
511
+ " \"feature_extractor_type\": \"CLIPFeatureExtractor\",\n",
512
+ " \"image_mean\": [\n",
513
+ " 0.48145466,\n",
514
+ " 0.4578275,\n",
515
+ " 0.40821073\n",
516
+ " ],\n",
517
+ " \"image_processor_type\": \"CLIPImageProcessor\",\n",
518
+ " \"image_std\": [\n",
519
+ " 0.26862954,\n",
520
+ " 0.26130258,\n",
521
+ " 0.27577711\n",
522
+ " ],\n",
523
+ " \"resample\": 3,\n",
524
+ " \"rescale_factor\": 0.00392156862745098,\n",
525
+ " \"size\": {\n",
526
+ " \"shortest_edge\": 224\n",
527
+ " }\n",
528
+ "}\n",
529
+ "\n",
530
+ "loading file vocab.json from cache at /home/ravi.naik/.cache/huggingface/hub/models--openai--clip-vit-base-patch32/snapshots/e6a30b603a447e251fdaca1c3056b2a16cdfebeb/vocab.json\n",
531
+ "loading file merges.txt from cache at /home/ravi.naik/.cache/huggingface/hub/models--openai--clip-vit-base-patch32/snapshots/e6a30b603a447e251fdaca1c3056b2a16cdfebeb/merges.txt\n",
532
+ "loading file tokenizer.json from cache at /home/ravi.naik/.cache/huggingface/hub/models--openai--clip-vit-base-patch32/snapshots/e6a30b603a447e251fdaca1c3056b2a16cdfebeb/tokenizer.json\n",
533
+ "loading file added_tokens.json from cache at None\n",
534
+ "loading file special_tokens_map.json from cache at /home/ravi.naik/.cache/huggingface/hub/models--openai--clip-vit-base-patch32/snapshots/e6a30b603a447e251fdaca1c3056b2a16cdfebeb/special_tokens_map.json\n",
535
+ "loading file tokenizer_config.json from cache at /home/ravi.naik/.cache/huggingface/hub/models--openai--clip-vit-base-patch32/snapshots/e6a30b603a447e251fdaca1c3056b2a16cdfebeb/tokenizer_config.json\n",
536
+ "loading configuration file config.json from cache at /home/ravi.naik/.cache/huggingface/hub/models--openai--clip-vit-base-patch32/snapshots/e6a30b603a447e251fdaca1c3056b2a16cdfebeb/config.json\n",
537
+ "`text_config` is `None`. Initializing the `CLIPTextConfig` with default values.\n",
538
+ "`vision_config` is `None`. initializing the `CLIPVisionConfig` with default values.\n",
539
+ "Model config CLIPConfig {\n",
540
+ " \"_name_or_path\": \"openai/clip-vit-base-patch32\",\n",
541
+ " \"architectures\": [\n",
542
+ " \"CLIPModel\"\n",
543
+ " ],\n",
544
+ " \"initializer_factor\": 1.0,\n",
545
+ " \"logit_scale_init_value\": 2.6592,\n",
546
+ " \"model_type\": \"clip\",\n",
547
+ " \"projection_dim\": 512,\n",
548
+ " \"text_config\": {\n",
549
+ " \"bos_token_id\": 0,\n",
550
+ " \"dropout\": 0.0,\n",
551
+ " \"eos_token_id\": 2,\n",
552
+ " \"model_type\": \"clip_text_model\"\n",
553
+ " },\n",
554
+ " \"transformers_version\": \"4.36.2\",\n",
555
+ " \"vision_config\": {\n",
556
+ " \"dropout\": 0.0,\n",
557
+ " \"model_type\": \"clip_vision_model\"\n",
558
+ " }\n",
559
+ "}\n",
560
+ "\n"
561
+ ]
562
+ }
563
+ ],
564
+ "source": [
565
+ "from PIL import Image\n",
566
+ "import requests\n",
567
+ "from transformers import AutoProcessor, CLIPVisionModel\n",
568
+ "\n",
569
+ "model = CLIPVisionModel.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
570
+ "processor = AutoProcessor.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
571
+ "\n",
572
+ "url = \"http://images.cocodataset.org/val2017/000000039769.jpg\"\n",
573
+ "image = Image.open(requests.get(url, stream=True).raw)\n",
574
+ "\n",
575
+ "inputs = processor(images=image, return_tensors=\"pt\")\n",
576
+ "\n",
577
+ "outputs = model(**inputs)\n",
578
+ "last_hidden_state = outputs.last_hidden_state\n",
579
+ "pooled_output = outputs.pooler_output # pooled CLS states"
580
+ ]
581
+ },
582
+ {
583
+ "cell_type": "code",
584
+ "execution_count": 31,
585
+ "id": "bcf0a7b3-6cbb-492e-bc2c-42e3edbe6a0c",
586
+ "metadata": {},
587
+ "outputs": [
588
+ {
589
+ "data": {
590
+ "text/plain": [
591
+ "torch.Size([1, 768])"
592
+ ]
593
+ },
594
+ "execution_count": 31,
595
+ "metadata": {},
596
+ "output_type": "execute_result"
597
+ }
598
+ ],
599
+ "source": [
600
+ "pooled_output.shape"
601
+ ]
602
+ },
603
+ {
604
+ "cell_type": "code",
605
+ "execution_count": 10,
606
+ "id": "67240294-c7a0-4e94-a8c1-86bfe1b21977",
607
+ "metadata": {},
608
+ "outputs": [],
609
+ "source": [
610
+ "from transformers import CLIPPreTrainedModel\n",
611
+ "from transformers.models.clip.modeling_clip import CLIPVisionModelOutput, CLIPVisionTransformer\n",
612
+ "from typing import Optional, Union, Tuple"
613
+ ]
614
+ },
615
+ {
616
+ "cell_type": "code",
617
+ "execution_count": 54,
618
+ "id": "cc9b20db-7f84-44c3-9c78-e84164ccc192",
619
+ "metadata": {},
620
+ "outputs": [],
621
+ "source": [
622
+ "class VisionLanguageConnector(nn.Module):\n",
623
+ " def __init__(self, hidden_size, projection_dim):\n",
624
+ " super().__init__()\n",
625
+ " self.mlp = nn.Sequential(\n",
626
+ " nn.Linear(hidden_size, hidden_size, bias=False),\n",
627
+ " nn.GELU(),\n",
628
+ " nn.Linear(hidden_size, projection_dim, bias=False)\n",
629
+ " )\n",
630
+ "\n",
631
+ " def forward(self, x):\n",
632
+ " return self.mlp(x)\n",
633
+ " \n",
634
+ "class ClipWithProjection(CLIPPreTrainedModel):\n",
635
+ " config_class = CLIPVisionConfig\n",
636
+ " main_input_name = \"pixel_values\"\n",
637
+ "\n",
638
+ " def __init__(self, config: CLIPVisionConfig):\n",
639
+ " super().__init__(config)\n",
640
+ "\n",
641
+ " self.vision_model = CLIPVisionTransformer(config)\n",
643
+ " self.vision_language_connector = VisionLanguageConnector(config.hidden_size, config.projection_dim)\n",
644
+ "\n",
645
+ " # Initialize weights and apply final processing\n",
646
+ " self.post_init()\n",
647
+ "\n",
648
+ " def forward(\n",
649
+ " self,\n",
650
+ " pixel_values: Optional[torch.FloatTensor] = None,\n",
651
+ " output_attentions: Optional[bool] = None,\n",
652
+ " output_hidden_states: Optional[bool] = None,\n",
653
+ " return_dict: Optional[bool] = None,\n",
654
+ " ) -> Union[Tuple, CLIPVisionModelOutput]:\n",
655
+ " return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n",
656
+ "\n",
657
+ " vision_outputs = self.vision_model(\n",
658
+ " pixel_values=pixel_values,\n",
659
+ " output_attentions=output_attentions,\n",
660
+ " output_hidden_states=output_hidden_states,\n",
661
+ " return_dict=return_dict,\n",
662
+ " )\n",
663
+ "\n",
664
+ " pooled_output = vision_outputs[1] # pooled_output\n",
665
+ "\n",
666
+ " image_embeds = self.vision_language_connector(pooled_output)\n",
667
+ "\n",
668
+ " if not return_dict:\n",
669
+ " outputs = (image_embeds, vision_outputs[0]) + vision_outputs[2:]\n",
670
+ " return tuple(output for output in outputs if output is not None)\n",
671
+ "\n",
672
+ " return CLIPVisionModelOutput(\n",
673
+ " image_embeds=image_embeds,\n",
674
+ " last_hidden_state=vision_outputs.last_hidden_state,\n",
675
+ " hidden_states=vision_outputs.hidden_states,\n",
676
+ " attentions=vision_outputs.attentions,\n",
677
+ " )"
678
+ ]
679
+ },
680
+ {
681
+ "cell_type": "code",
682
+ "execution_count": 55,
683
+ "id": "a4892ab8-39d2-41c9-ad2a-04711c22b95f",
684
+ "metadata": {
685
+ "collapsed": true,
686
+ "jupyter": {
687
+ "outputs_hidden": true
688
+ }
689
+ },
690
+ "outputs": [
691
+ {
692
+ "name": "stderr",
693
+ "output_type": "stream",
694
+ "text": [
695
+ "loading configuration file config.json from cache at /home/ravi.naik/.cache/huggingface/hub/models--openai--clip-vit-base-patch32/snapshots/e6a30b603a447e251fdaca1c3056b2a16cdfebeb/config.json\n",
696
+ "Model config CLIPVisionConfig {\n",
697
+ " \"attention_dropout\": 0.0,\n",
698
+ " \"dropout\": 0.0,\n",
699
+ " \"hidden_act\": \"quick_gelu\",\n",
700
+ " \"hidden_size\": 768,\n",
701
+ " \"image_size\": 224,\n",
702
+ " \"initializer_factor\": 1.0,\n",
703
+ " \"initializer_range\": 0.02,\n",
704
+ " \"intermediate_size\": 3072,\n",
705
+ " \"layer_norm_eps\": 1e-05,\n",
706
+ " \"model_type\": \"clip_vision_model\",\n",
707
+ " \"num_attention_heads\": 12,\n",
708
+ " \"num_channels\": 3,\n",
709
+ " \"num_hidden_layers\": 12,\n",
710
+ " \"patch_size\": 32,\n",
711
+ " \"projection_dim\": 512,\n",
712
+ " \"transformers_version\": \"4.36.2\"\n",
713
+ "}\n",
714
+ "\n",
715
+ "loading weights file pytorch_model.bin from cache at /home/ravi.naik/.cache/huggingface/hub/models--openai--clip-vit-base-patch32/snapshots/e6a30b603a447e251fdaca1c3056b2a16cdfebeb/pytorch_model.bin\n",
716
+ "Some weights of the model checkpoint at openai/clip-vit-base-patch32 were not used when initializing ClipWithProjection: ['text_model.encoder.layers.8.mlp.fc2.weight', 'text_model.encoder.layers.4.self_attn.v_proj.weight', 'text_model.encoder.layers.2.mlp.fc2.bias', 'text_model.encoder.layers.2.self_attn.q_proj.bias', 'text_model.encoder.layers.4.self_attn.v_proj.bias', 'text_model.encoder.layers.6.self_attn.k_proj.bias', 'text_model.encoder.layers.9.self_attn.k_proj.weight', 'text_model.encoder.layers.11.self_attn.q_proj.bias', 'text_model.encoder.layers.3.self_attn.out_proj.bias', 'text_model.encoder.layers.11.self_attn.k_proj.bias', 'text_model.encoder.layers.7.self_attn.k_proj.bias', 'text_model.encoder.layers.1.self_attn.q_proj.weight', 'text_model.encoder.layers.11.layer_norm1.bias', 'text_model.encoder.layers.11.mlp.fc2.bias', 'text_model.encoder.layers.10.layer_norm1.bias', 'text_model.encoder.layers.7.self_attn.q_proj.bias', 'text_model.encoder.layers.11.self_attn.k_proj.weight', 'text_model.encoder.layers.5.self_attn.q_proj.bias', 'text_model.encoder.layers.0.self_attn.v_proj.bias', 'logit_scale', 'text_model.encoder.layers.8.mlp.fc1.bias', 'text_model.encoder.layers.6.layer_norm1.bias', 'text_model.encoder.layers.5.self_attn.out_proj.weight', 'text_model.encoder.layers.7.self_attn.out_proj.bias', 'text_model.embeddings.token_embedding.weight', 'text_model.encoder.layers.8.layer_norm2.bias', 'text_model.encoder.layers.9.self_attn.v_proj.weight', 'text_model.encoder.layers.1.mlp.fc1.weight', 'text_model.encoder.layers.0.layer_norm1.weight', 'text_model.encoder.layers.6.self_attn.k_proj.weight', 'text_model.encoder.layers.3.self_attn.q_proj.weight', 'text_model.encoder.layers.2.layer_norm1.weight', 'text_model.encoder.layers.0.self_attn.v_proj.weight', 'text_model.encoder.layers.7.self_attn.q_proj.weight', 'text_model.encoder.layers.1.layer_norm2.weight', 'text_model.encoder.layers.2.self_attn.out_proj.weight', 'text_model.encoder.layers.3.self_attn.k_proj.weight', 'text_model.encoder.layers.7.mlp.fc2.bias', 'text_model.encoder.layers.10.self_attn.out_proj.weight', 'text_model.encoder.layers.2.self_attn.q_proj.weight', 'text_model.encoder.layers.1.self_attn.k_proj.weight', 'text_model.encoder.layers.4.layer_norm1.weight', 'text_model.encoder.layers.0.self_attn.q_proj.weight', 'text_model.encoder.layers.0.self_attn.out_proj.bias', 'text_model.encoder.layers.4.self_attn.out_proj.bias', 'text_model.encoder.layers.5.self_attn.k_proj.weight', 'visual_projection.weight', 'text_model.encoder.layers.6.layer_norm2.bias', 'text_model.encoder.layers.6.layer_norm1.weight', 'text_model.encoder.layers.4.self_attn.out_proj.weight', 'text_model.encoder.layers.10.mlp.fc2.bias', 'text_model.encoder.layers.10.mlp.fc1.weight', 'text_model.encoder.layers.6.self_attn.out_proj.weight', 'text_model.encoder.layers.9.layer_norm1.weight', 'text_model.encoder.layers.11.layer_norm2.weight', 'text_model.encoder.layers.6.self_attn.q_proj.bias', 'text_model.encoder.layers.5.mlp.fc1.weight', 'text_model.encoder.layers.2.mlp.fc1.weight', 'text_model.encoder.layers.11.self_attn.out_proj.weight', 'text_model.encoder.layers.0.self_attn.out_proj.weight', 'text_model.encoder.layers.11.mlp.fc2.weight', 'text_model.encoder.layers.7.layer_norm2.weight', 'text_model.encoder.layers.10.self_attn.v_proj.bias', 'text_model.encoder.layers.9.mlp.fc1.bias', 'text_model.encoder.layers.8.self_attn.v_proj.weight', 'text_model.encoder.layers.3.layer_norm1.bias', 'text_model.encoder.layers.6.self_attn.v_proj.bias', 
'text_model.encoder.layers.1.self_attn.v_proj.bias', 'text_model.encoder.layers.9.self_attn.q_proj.weight', 'text_model.encoder.layers.4.self_attn.k_proj.weight', 'text_model.encoder.layers.7.layer_norm1.weight', 'text_model.encoder.layers.10.self_attn.k_proj.weight', 'text_model.encoder.layers.7.self_attn.v_proj.bias', 'text_model.encoder.layers.7.mlp.fc1.bias', 'text_model.encoder.layers.11.mlp.fc1.weight', 'text_model.encoder.layers.2.mlp.fc1.bias', 'text_model.encoder.layers.3.mlp.fc2.bias', 'text_model.encoder.layers.8.self_attn.q_proj.weight', 'text_model.encoder.layers.0.mlp.fc1.weight', 'text_model.encoder.layers.11.self_attn.out_proj.bias', 'text_model.encoder.layers.1.self_attn.v_proj.weight', 'text_model.encoder.layers.0.self_attn.k_proj.weight', 'text_model.encoder.layers.9.layer_norm1.bias', 'text_model.final_layer_norm.weight', 'text_model.encoder.layers.3.layer_norm1.weight', 'text_model.encoder.layers.4.mlp.fc1.bias', 'text_model.encoder.layers.1.layer_norm1.weight', 'text_model.encoder.layers.10.layer_norm2.bias', 'text_model.encoder.layers.9.self_attn.v_proj.bias', 'text_model.encoder.layers.10.self_attn.k_proj.bias', 'text_model.encoder.layers.8.mlp.fc2.bias', 'text_model.encoder.layers.5.mlp.fc2.bias', 'text_model.encoder.layers.6.self_attn.q_proj.weight', 'text_model.encoder.layers.5.self_attn.out_proj.bias', 'text_model.encoder.layers.9.mlp.fc2.bias', 'text_model.encoder.layers.5.layer_norm2.weight', 'text_model.encoder.layers.2.mlp.fc2.weight', 'text_model.encoder.layers.3.self_attn.out_proj.weight', 'text_model.encoder.layers.6.mlp.fc2.weight', 'text_model.encoder.layers.1.self_attn.out_proj.weight', 'text_model.encoder.layers.1.mlp.fc2.bias', 'text_model.encoder.layers.7.mlp.fc2.weight', 'text_model.encoder.layers.10.self_attn.v_proj.weight', 'text_model.encoder.layers.11.self_attn.v_proj.bias', 'text_model.encoder.layers.4.layer_norm1.bias', 'text_model.encoder.layers.4.layer_norm2.bias', 'text_model.encoder.layers.8.self_attn.q_proj.bias', 'text_model.embeddings.position_ids', 'text_model.encoder.layers.10.layer_norm2.weight', 'text_model.encoder.layers.1.self_attn.out_proj.bias', 'text_model.encoder.layers.2.layer_norm2.weight', 'text_model.encoder.layers.10.self_attn.q_proj.weight', 'text_model.encoder.layers.4.mlp.fc1.weight', 'text_model.encoder.layers.8.layer_norm1.bias', 'text_model.encoder.layers.2.self_attn.k_proj.weight', 'text_model.encoder.layers.5.mlp.fc1.bias', 'text_model.encoder.layers.9.self_attn.out_proj.bias', 'text_model.encoder.layers.7.self_attn.v_proj.weight', 'text_model.encoder.layers.2.self_attn.k_proj.bias', 'text_model.encoder.layers.5.self_attn.k_proj.bias', 'text_model.encoder.layers.8.self_attn.out_proj.bias', 'text_model.encoder.layers.7.self_attn.k_proj.weight', 'text_model.encoder.layers.6.mlp.fc1.weight', 'text_model.encoder.layers.6.mlp.fc1.bias', 'text_model.encoder.layers.3.self_attn.v_proj.weight', 'text_model.encoder.layers.3.self_attn.q_proj.bias', 'text_model.encoder.layers.9.self_attn.out_proj.weight', 'text_model.encoder.layers.3.mlp.fc1.bias', 'text_model.encoder.layers.0.self_attn.q_proj.bias', 'text_model.encoder.layers.1.layer_norm2.bias', 'text_model.encoder.layers.8.layer_norm2.weight', 'text_model.encoder.layers.5.self_attn.q_proj.weight', 'text_model.encoder.layers.4.layer_norm2.weight', 'text_model.encoder.layers.4.mlp.fc2.bias', 'text_model.encoder.layers.9.mlp.fc2.weight', 'text_model.encoder.layers.8.self_attn.k_proj.weight', 'text_model.encoder.layers.10.layer_norm1.weight', 
'text_model.encoder.layers.0.self_attn.k_proj.bias', 'text_model.encoder.layers.8.self_attn.k_proj.bias', 'text_model.encoder.layers.9.layer_norm2.weight', 'text_model.encoder.layers.4.self_attn.k_proj.bias', 'text_model.encoder.layers.6.layer_norm2.weight', 'text_model.encoder.layers.0.layer_norm2.weight', 'text_model.encoder.layers.5.self_attn.v_proj.bias', 'text_model.encoder.layers.3.layer_norm2.bias', 'text_model.encoder.layers.8.mlp.fc1.weight', 'text_model.encoder.layers.4.self_attn.q_proj.bias', 'text_model.encoder.layers.8.layer_norm1.weight', 'text_model.encoder.layers.2.self_attn.v_proj.weight', 'text_model.encoder.layers.3.self_attn.v_proj.bias', 'text_model.encoder.layers.11.mlp.fc1.bias', 'text_model.encoder.layers.6.mlp.fc2.bias', 'text_model.encoder.layers.1.mlp.fc1.bias', 'text_model.encoder.layers.2.self_attn.v_proj.bias', 'text_model.encoder.layers.5.mlp.fc2.weight', 'text_model.encoder.layers.8.self_attn.v_proj.bias', 'text_model.encoder.layers.10.self_attn.out_proj.bias', 'text_model.encoder.layers.5.layer_norm1.bias', 'text_model.encoder.layers.5.self_attn.v_proj.weight', 'text_model.encoder.layers.10.self_attn.q_proj.bias', 'text_model.encoder.layers.2.layer_norm2.bias', 'text_model.encoder.layers.7.layer_norm1.bias', 'text_model.encoder.layers.4.mlp.fc2.weight', 'text_model.encoder.layers.10.mlp.fc2.weight', 'text_model.encoder.layers.3.mlp.fc1.weight', 'text_model.encoder.layers.5.layer_norm2.bias', 'text_model.encoder.layers.9.self_attn.q_proj.bias', 'text_model.encoder.layers.1.self_attn.k_proj.bias', 'text_model.encoder.layers.7.self_attn.out_proj.weight', 'text_model.encoder.layers.0.mlp.fc2.weight', 'text_model.encoder.layers.11.self_attn.v_proj.weight', 'text_model.encoder.layers.1.layer_norm1.bias', 'text_model.encoder.layers.1.mlp.fc2.weight', 'text_model.encoder.layers.9.layer_norm2.bias', 'text_model.encoder.layers.9.self_attn.k_proj.bias', 'text_model.encoder.layers.11.layer_norm1.weight', 'text_model.encoder.layers.8.self_attn.out_proj.weight', 'text_model.encoder.layers.0.layer_norm1.bias', 'text_model.encoder.layers.7.mlp.fc1.weight', 'text_model.encoder.layers.0.mlp.fc1.bias', 'text_model.encoder.layers.0.layer_norm2.bias', 'text_model.encoder.layers.3.self_attn.k_proj.bias', 'text_model.encoder.layers.5.layer_norm1.weight', 'text_model.encoder.layers.3.layer_norm2.weight', 'text_model.encoder.layers.1.self_attn.q_proj.bias', 'text_model.encoder.layers.2.self_attn.out_proj.bias', 'text_model.encoder.layers.3.mlp.fc2.weight', 'text_model.encoder.layers.11.self_attn.q_proj.weight', 'text_model.final_layer_norm.bias', 'text_model.encoder.layers.6.self_attn.v_proj.weight', 'text_model.encoder.layers.0.mlp.fc2.bias', 'text_model.encoder.layers.7.layer_norm2.bias', 'text_model.encoder.layers.10.mlp.fc1.bias', 'text_model.embeddings.position_embedding.weight', 'text_model.encoder.layers.6.self_attn.out_proj.bias', 'text_model.encoder.layers.2.layer_norm1.bias', 'text_model.encoder.layers.9.mlp.fc1.weight', 'text_projection.weight', 'text_model.encoder.layers.11.layer_norm2.bias', 'text_model.encoder.layers.4.self_attn.q_proj.weight']\n",
717
+ "- This IS expected if you are initializing ClipWithProjection from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
718
+ "- This IS NOT expected if you are initializing ClipWithProjection from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
719
+ "Some weights of ClipWithProjection were not initialized from the model checkpoint at openai/clip-vit-base-patch32 and are newly initialized: ['vision_language_connector.mlp.2.weight', 'vision_language_connector.mlp.0.weight']\n",
720
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
721
+ ]
722
+ }
723
+ ],
724
+ "source": [
725
+ "model = ClipWithProjection.from_pretrained(\"openai/clip-vit-base-patch32\")"
726
+ ]
727
+ },
728
+ {
729
+ "cell_type": "code",
730
+ "execution_count": 56,
731
+ "id": "588ef914-5be9-49e1-b68d-b899e0e74edd",
732
+ "metadata": {},
733
+ "outputs": [
734
+ {
735
+ "data": {
736
+ "text/plain": [
737
+ "768"
738
+ ]
739
+ },
740
+ "execution_count": 56,
741
+ "metadata": {},
742
+ "output_type": "execute_result"
743
+ }
744
+ ],
745
+ "source": [
746
+ "model.config.hidden_size"
747
+ ]
748
+ },
749
+ {
750
+ "cell_type": "code",
751
+ "execution_count": 57,
752
+ "id": "05d95b9e-9831-4415-860e-94793e29d210",
753
+ "metadata": {},
754
+ "outputs": [],
755
+ "source": [
756
+ "outputs = model(**inputs)"
757
+ ]
758
+ },
759
+ {
760
+ "cell_type": "code",
761
+ "execution_count": 61,
762
+ "id": "185b1bff-6ffe-4cce-9255-ee7629feba54",
763
+ "metadata": {},
764
+ "outputs": [
765
+ {
766
+ "data": {
767
+ "text/plain": [
768
+ "torch.Size([1, 512])"
769
+ ]
770
+ },
771
+ "execution_count": 61,
772
+ "metadata": {},
773
+ "output_type": "execute_result"
774
+ }
775
+ ],
776
+ "source": [
777
+ "outputs[0].shape"
778
+ ]
779
+ },
780
+ {
781
+ "cell_type": "code",
782
+ "execution_count": null,
783
+ "id": "04414a35-c7b3-4986-a79e-1d363916caa4",
784
+ "metadata": {},
785
+ "outputs": [],
786
+ "source": []
787
+ },
788
+ {
789
+ "cell_type": "code",
790
+ "execution_count": 1,
791
+ "id": "485dbbcb-06df-4926-b257-dfd1a4081d44",
792
+ "metadata": {},
793
+ "outputs": [
794
+ {
795
+ "ename": "NameError",
796
+ "evalue": "name 'outputs' is not defined",
797
+ "output_type": "error",
798
+ "traceback": [
799
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
800
+ "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
801
+ "Cell \u001b[0;32mIn[1], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43moutputs\u001b[49m[\u001b[38;5;241m0\u001b[39m]\n",
802
+ "\u001b[0;31mNameError\u001b[0m: name 'outputs' is not defined"
803
+ ]
804
+ }
805
+ ],
806
+ "source": [
807
+ "outputs[0]"
808
+ ]
809
+ },
810
+ {
811
+ "cell_type": "code",
812
+ "execution_count": null,
813
+ "id": "f983313c-8e0f-4805-af14-25bb69afd04c",
814
+ "metadata": {},
815
+ "outputs": [],
816
+ "source": []
817
+ }
818
+ ],
819
+ "metadata": {
820
+ "kernelspec": {
821
+ "display_name": "Python 3 (ipykernel)",
822
+ "language": "python",
823
+ "name": "python3"
824
+ },
825
+ "language_info": {
826
+ "codemirror_mode": {
827
+ "name": "ipython",
828
+ "version": 3
829
+ },
830
+ "file_extension": ".py",
831
+ "mimetype": "text/x-python",
832
+ "name": "python",
833
+ "nbconvert_exporter": "python",
834
+ "pygments_lexer": "ipython3",
835
+ "version": "3.10.12"
836
+ }
837
+ },
838
+ "nbformat": 4,
839
+ "nbformat_minor": 5
840
+ }
Experiments/eval.ipynb ADDED
@@ -0,0 +1,782 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 4,
6
+ "id": "215cfd2f-62b0-4a86-a407-777a1d32597f",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stdout",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "[2024-01-24 15:18:49,948] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n"
14
+ ]
15
+ }
16
+ ],
17
+ "source": [
18
+ "from PIL import Image\n",
19
+ "import requests\n",
20
+ "\n",
21
+ "import torch\n",
22
+ "from torch import nn\n",
23
+ "from transformers import AutoProcessor, CLIPVisionModel, CLIPVisionConfig, CLIPPreTrainedModel\n",
24
+ "from transformers.models.clip.modeling_clip import CLIPVisionModelOutput, CLIPVisionTransformer\n",
25
+ "from transformers import WhisperProcessor, WhisperForConditionalGeneration\n",
26
+ "from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoTokenizer"
27
+ ]
28
+ },
29
+ {
30
+ "cell_type": "code",
31
+ "execution_count": 5,
32
+ "id": "2244e8f3-fcc7-4309-9d4d-fea557f89f79",
33
+ "metadata": {},
34
+ "outputs": [],
35
+ "source": [
36
+ "from llava_phi import LlavaPhiForCausalLM"
37
+ ]
38
+ },
39
+ {
40
+ "cell_type": "code",
41
+ "execution_count": 3,
42
+ "id": "587883e1-3419-4b14-b16b-38fabbc8bfaa",
43
+ "metadata": {},
44
+ "outputs": [],
45
+ "source": [
46
+ "# model = LlavaPhiForCausalLM.from_pretrained(\"./llava-phi/checkpoints/llavaPhi-v0-3b-finetune/checkpoint-4000\")"
47
+ ]
48
+ },
49
+ {
50
+ "cell_type": "code",
51
+ "execution_count": 4,
52
+ "id": "0e27a7db-e2ab-4d65-b21d-497222e318ad",
53
+ "metadata": {},
54
+ "outputs": [],
55
+ "source": [
56
+ "# processor = AutoProcessor.from_pretrained(\"./llava-phi/checkpoints/llavaPhi-v0-3b-finetune/checkpoint-4000\")"
57
+ ]
58
+ },
59
+ {
60
+ "cell_type": "code",
61
+ "execution_count": 5,
62
+ "id": "663efdd8-ea21-4231-a2ae-bcc0fb47b46a",
63
+ "metadata": {},
64
+ "outputs": [],
65
+ "source": [
66
+ "# prompt = \"<image>\\nUSER: What's the content of the image?\\nASSISTANT:\"\n",
67
+ "# url = \"https://www.ilankelman.org/stopsigns/australia.jpg\"\n",
68
+ "# image = Image.open(requests.get(url, stream=True).raw)"
69
+ ]
70
+ },
71
+ {
72
+ "cell_type": "code",
73
+ "execution_count": 6,
74
+ "id": "f622609f-f6a7-4ec1-ac35-c1d33d9436ca",
75
+ "metadata": {},
76
+ "outputs": [],
77
+ "source": [
78
+ "# # Generate\n",
79
+ "# generate_ids = model.generate(**inputs, max_length=30)\n",
80
+ "# processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]"
81
+ ]
82
+ },
83
+ {
84
+ "cell_type": "code",
85
+ "execution_count": 6,
86
+ "id": "45f5ba72-2e41-4ccc-84c1-97d542ebee63",
87
+ "metadata": {},
88
+ "outputs": [],
89
+ "source": [
90
+ "from llava_phi.model.builder import load_pretrained_model\n",
91
+ "from llava_phi.mm_utils import tokenizer_image_token, get_model_name_from_path\n",
92
+ "from llava_phi.utils import disable_torch_init\n",
93
+ "from llava_phi.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN\n",
94
+ "from llava_phi.conversation import conv_templates, SeparatorStyle"
95
+ ]
96
+ },
97
+ {
98
+ "cell_type": "code",
99
+ "execution_count": 11,
100
+ "id": "b98ac5d3-5503-4430-81d1-19a4f8d6bd75",
101
+ "metadata": {},
102
+ "outputs": [],
103
+ "source": [
104
+ "model_path = \"checkpoints/llavaPhi-v0-3b-finetune/checkpoint-4000\"\n",
105
+ "model_name = get_model_name_from_path(model_path)"
106
+ ]
107
+ },
108
+ {
109
+ "cell_type": "code",
110
+ "execution_count": 12,
111
+ "id": "42fd5721-75a7-475b-bd30-5ee23aeaac64",
112
+ "metadata": {},
113
+ "outputs": [
114
+ {
115
+ "data": {
116
+ "text/plain": [
117
+ "'llavaPhi-v0-3b-finetune_checkpoint-4000'"
118
+ ]
119
+ },
120
+ "execution_count": 12,
121
+ "metadata": {},
122
+ "output_type": "execute_result"
123
+ }
124
+ ],
125
+ "source": [
126
+ "model_name"
127
+ ]
128
+ },
129
+ {
130
+ "cell_type": "code",
131
+ "execution_count": 13,
132
+ "id": "8c2076b5-3bfc-48fd-917b-5dfd06fc532f",
133
+ "metadata": {},
134
+ "outputs": [
135
+ {
136
+ "name": "stderr",
137
+ "output_type": "stream",
138
+ "text": [
139
+ "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.\n"
140
+ ]
141
+ },
142
+ {
143
+ "name": "stdout",
144
+ "output_type": "stream",
145
+ "text": [
146
+ "load llaVA-Phi MLLM!!!\n"
147
+ ]
148
+ },
149
+ {
150
+ "name": "stderr",
151
+ "output_type": "stream",
152
+ "text": [
153
+ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
154
+ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
155
+ ]
156
+ },
157
+ {
158
+ "data": {
159
+ "application/vnd.jupyter.widget-view+json": {
160
+ "model_id": "20b86f2c01744081b537620c8780f12e",
161
+ "version_major": 2,
162
+ "version_minor": 0
163
+ },
164
+ "text/plain": [
165
+ "Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
166
+ ]
167
+ },
168
+ "metadata": {},
169
+ "output_type": "display_data"
170
+ },
171
+ {
172
+ "name": "stdout",
173
+ "output_type": "stream",
174
+ "text": [
175
+ "{'device_map': 'cuda'}\n"
176
+ ]
177
+ }
178
+ ],
179
+ "source": [
180
+ "tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name)"
181
+ ]
182
+ },
183
+ {
184
+ "cell_type": "code",
185
+ "execution_count": 14,
186
+ "id": "4e46221e-0907-453e-8126-76199828493e",
187
+ "metadata": {},
188
+ "outputs": [],
189
+ "source": [
190
+ "qs = \"What's the content of the image?\"\n",
191
+ "qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\\n' + qs"
192
+ ]
193
+ },
194
+ {
195
+ "cell_type": "code",
196
+ "execution_count": 15,
197
+ "id": "07355444-0eb8-4d4d-ad50-48b91c969664",
198
+ "metadata": {},
199
+ "outputs": [],
200
+ "source": [
201
+ "conv = conv_templates[\"default\"].copy()\n",
202
+ "conv.append_message(conv.roles[0], qs)\n",
203
+ "conv.append_message(conv.roles[1], None)\n",
204
+ "prompt = conv.get_prompt()"
205
+ ]
206
+ },
207
+ {
208
+ "cell_type": "code",
209
+ "execution_count": 16,
210
+ "id": "ccb5674f-aff8-456e-b61b-1d167864f1a6",
211
+ "metadata": {},
212
+ "outputs": [
213
+ {
214
+ "data": {
215
+ "text/plain": [
216
+ "\"A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: <im_start><image><im_end>\\nWhat's the content of the image? ASSISTANT:\""
217
+ ]
218
+ },
219
+ "execution_count": 16,
220
+ "metadata": {},
221
+ "output_type": "execute_result"
222
+ }
223
+ ],
224
+ "source": [
225
+ "prompt"
226
+ ]
227
+ },
228
+ {
229
+ "cell_type": "code",
230
+ "execution_count": 17,
231
+ "id": "a89cc181-2214-4844-b966-164a41744e54",
232
+ "metadata": {},
233
+ "outputs": [],
234
+ "source": [
235
+ "url = \"https://www.ilankelman.org/stopsigns/australia.jpg\"\n",
236
+ "image = Image.open(requests.get(url, stream=True).raw)\n",
237
+ "image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].cuda()\n",
238
+ "\n",
239
+ "input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()\n",
240
+ "\n",
241
+ "stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2"
242
+ ]
243
+ },
244
+ {
245
+ "cell_type": "code",
246
+ "execution_count": 25,
247
+ "id": "0d519851-64d4-4cf5-b2eb-19474f9aa260",
248
+ "metadata": {},
249
+ "outputs": [
250
+ {
251
+ "data": {
252
+ "text/plain": [
253
+ "torch.Size([1, 55])"
254
+ ]
255
+ },
256
+ "execution_count": 25,
257
+ "metadata": {},
258
+ "output_type": "execute_result"
259
+ }
260
+ ],
261
+ "source": [
262
+ "input_ids.shape"
263
+ ]
264
+ },
265
+ {
266
+ "cell_type": "code",
267
+ "execution_count": 24,
268
+ "id": "1694ff36-f214-4ed3-b2f3-d3dbd0a1a25b",
269
+ "metadata": {},
270
+ "outputs": [
271
+ {
272
+ "name": "stderr",
273
+ "output_type": "stream",
274
+ "text": [
275
+ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
276
+ ]
277
+ }
278
+ ],
279
+ "source": [
280
+ "from datasets import load_dataset\n",
281
+ "audio_ds = load_dataset(\"hf-internal-testing/librispeech_asr_dummy\", \"clean\", split=\"validation\")\n",
282
+ "audio = audio_ds[0][\"audio\"]\n",
283
+ "\n",
284
+ "whisper_w_proj = WhisperWithProjection(projection_dim=512)\n",
285
+ "audio_embed = whisper_w_proj(audio)[\"input_ids\"]"
286
+ ]
287
+ },
288
+ {
289
+ "cell_type": "code",
290
+ "execution_count": 28,
291
+ "id": "9c4a9fae-d6ed-4fc2-ba02-97df64cddd93",
292
+ "metadata": {},
293
+ "outputs": [
294
+ {
295
+ "data": {
296
+ "text/plain": [
297
+ "(torch.Size([1, 33]), device(type='cpu'))"
298
+ ]
299
+ },
300
+ "execution_count": 28,
301
+ "metadata": {},
302
+ "output_type": "execute_result"
303
+ }
304
+ ],
305
+ "source": [
306
+ "audio_embed.shape, audio_embed.device"
307
+ ]
308
+ },
309
+ {
310
+ "cell_type": "code",
311
+ "execution_count": 29,
312
+ "id": "c3fffe29-98fb-4f4b-ac51-4bdda9e46752",
313
+ "metadata": {},
314
+ "outputs": [],
315
+ "source": [
316
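+ "# 55 prompt token ids + 33 audio-transcription token ids -> the [1, 88] tensor checked in the next cell\n",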
+ "input_ids = torch.concat([input_ids, audio_embed.to(\"cuda:0\")], dim=1)"
317
+ ]
318
+ },
319
+ {
320
+ "cell_type": "code",
321
+ "execution_count": 30,
322
+ "id": "5dee1ec8-2db2-4f65-99e8-d34bd2735c9c",
323
+ "metadata": {},
324
+ "outputs": [
325
+ {
326
+ "data": {
327
+ "text/plain": [
328
+ "torch.Size([1, 88])"
329
+ ]
330
+ },
331
+ "execution_count": 30,
332
+ "metadata": {},
333
+ "output_type": "execute_result"
334
+ }
335
+ ],
336
+ "source": [
337
+ "input_ids.shape"
338
+ ]
339
+ },
340
+ {
341
+ "cell_type": "code",
342
+ "execution_count": 31,
343
+ "id": "96033b43-4f57-4f0c-bcf7-37b57ca02e47",
344
+ "metadata": {},
345
+ "outputs": [],
346
+ "source": [
347
+ "with torch.inference_mode():\n",
348
+ " output_ids = model.generate(\n",
349
+ " input_ids,\n",
350
+ " images=image_tensor,\n",
351
+ " do_sample=True,\n",
352
+ " temperature=0.2,\n",
353
+ " max_new_tokens=1024,\n",
354
+ " eos_token_id=tokenizer.eos_token_id, # End of sequence token\n",
355
+ " pad_token_id=tokenizer.eos_token_id, # Pad token\n",
356
+ " use_cache=True,\n",
357
+ " )"
358
+ ]
359
+ },
360
+ {
361
+ "cell_type": "code",
362
+ "execution_count": 32,
363
+ "id": "741e8da5-0d18-4c11-b559-76054ce4ca3a",
364
+ "metadata": {},
365
+ "outputs": [
366
+ {
367
+ "name": "stdout",
368
+ "output_type": "stream",
369
+ "text": [
370
+ "is a Japanese character from the story of Jesus, who is a Chinese monk who is also known for his teachings. The story is based on the story of the story of Jesus Christ, and it is a representation of the story of Jesus and the story of Jesus Christ.\n"
371
+ ]
372
+ }
373
+ ],
374
+ "source": [
375
+ "input_token_len = input_ids.shape[1]\n",
376
+ "n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()\n",
377
+ "if n_diff_input_output > 0:\n",
378
+ " print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')\n",
379
+ "outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]\n",
380
+ "outputs = outputs.strip()\n",
381
+ "if outputs.endswith(stop_str):\n",
382
+ " outputs = outputs[:-len(stop_str)]\n",
383
+ "outputs = outputs.strip()\n",
384
+ "print(outputs)"
385
+ ]
386
+ },
387
+ {
388
+ "cell_type": "code",
389
+ "execution_count": 20,
390
+ "id": "69d494d4-d768-4645-b4d6-5c455791b50d",
391
+ "metadata": {},
392
+ "outputs": [],
393
+ "source": [
394
+ "# image"
395
+ ]
396
+ },
397
+ {
398
+ "cell_type": "code",
399
+ "execution_count": null,
400
+ "id": "8a340856-a13f-4b18-9911-126a4ba37816",
401
+ "metadata": {},
402
+ "outputs": [],
403
+ "source": []
404
+ },
405
+ {
406
+ "cell_type": "code",
407
+ "execution_count": null,
408
+ "id": "3c56fdea-c7a1-4e67-9832-e2ed077d8704",
409
+ "metadata": {},
410
+ "outputs": [],
411
+ "source": []
412
+ },
413
+ {
414
+ "cell_type": "code",
415
+ "execution_count": 52,
416
+ "id": "89e84d39-8ed8-45db-ae82-27c156ee6dd1",
417
+ "metadata": {},
418
+ "outputs": [],
419
+ "source": [
420
+ "class AudioLanguageConnector:\n",
421
+ " def __init__(self, projection_dim):\n",
422
+ " model_name = \"microsoft/phi-2\"\n",
423
+ " self.phi2_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n",
424
+ " self.phi2_tokenizer.pad_token = self.phi2_tokenizer.eos_token\n",
425
+ " self.phi2_tokenizer.max_length = projection_dim\n",
426
+ "\n",
427
+ " def __call__(self, text):\n",
428
+ " text = f\"<audio_start> {text} <audio_end>\"\n",
429
+ " tokens = self.phi2_tokenizer(text, return_tensors=\"pt\", return_attention_mask=False)\n",
430
+ " return tokens\n",
431
+ " \n",
432
+ "\n",
433
+ "class WhisperWithProjection:\n",
434
+ " def __init__(self, projection_dim, device):\n",
435
+ " self.device = device\n",
436
+ " self.processor = WhisperProcessor.from_pretrained(\"openai/whisper-tiny\", device_map=device)\n",
437
+ " self.model = WhisperForConditionalGeneration.from_pretrained(\"openai/whisper-tiny\", device_map=device)\n",
438
+ " self.model.config.forced_decoder_ids = None\n",
439
+ " self.audio_language_connector = AudioLanguageConnector(projection_dim)\n",
440
+ " \n",
441
+ " def __call__(self, audio):\n",
442
+ " input_features = self.processor(audio[\"array\"],\n",
443
+ " sampling_rate=audio[\"sampling_rate\"],\n",
444
+ " return_tensors=\"pt\").input_features\n",
445
+ " # generate token ids\n",
446
+ " predicted_ids = self.model.generate(input_features.to(self.device))\n",
447
+ " # decode token ids to text \n",
448
+ " transcription = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)\n",
449
+ "\n",
450
+ " audio_embeddings = self.audio_language_connector(transcription)\n",
451
+ " return audio_embeddings.to(self.device)"
452
+ ]
453
+ },
454
+ {
455
+ "cell_type": "code",
456
+ "execution_count": 53,
457
+ "id": "75e24be0-b236-4047-83ef-5c344e262476",
458
+ "metadata": {},
459
+ "outputs": [],
460
+ "source": [
461
+ "class MultiModalPhi2:\n",
462
+ " def __init__(self, model_path=\"checkpoints/llavaPhi-v0-3b-finetune/checkpoint-4000\",\n",
463
+ " temperature=0.2,\n",
464
+ " max_new_tokens=1024,\n",
465
+ " device=\"cuda\"):\n",
466
+ " self.temperature = temperature\n",
467
+ " self.max_new_tokens = max_new_tokens\n",
468
+ " self.device = device\n",
469
+ " model_name = get_model_name_from_path(model_path)\n",
470
+ " self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(model_path, None, model_name, device_map=device)\n",
471
+ " self.whisper_w_proj = WhisperWithProjection(projection_dim=512, device=device)\n",
472
+ " \n",
473
+ " \n",
474
+ " def __call__(self, text, audio, image):\n",
475
+ " qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\\n' + text\n",
476
+ " conv = conv_templates[\"default\"].copy()\n",
477
+ " conv.append_message(conv.roles[0], qs)\n",
478
+ " conv.append_message(conv.roles[1], None)\n",
479
+ " prompt = conv.get_prompt()\n",
480
+ "\n",
481
+ " image_tensor = self.image_processor.preprocess(image, return_tensors='pt')['pixel_values'].cuda()\n",
482
+ " \n",
483
+ " input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()\n",
484
+ "\n",
485
+ " audio_embed = self.whisper_w_proj(audio)[\"input_ids\"]\n",
486
+ " \n",
487
+ " stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2\n",
488
+ "\n",
489
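+ "        # audio_embed holds token ids of the Whisper transcription, so it can be appended directly to the prompt ids\n",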
+ " input_ids = torch.concat([input_ids, audio_embed], dim=1)\n",
490
+ "\n",
491
+ " with torch.inference_mode():\n",
492
+ " output_ids = self.model.generate(\n",
493
+ " input_ids,\n",
494
+ " images=image_tensor,\n",
495
+ " do_sample=True,\n",
496
+ " temperature=self.temperature,\n",
497
+ " max_new_tokens=self.max_new_tokens,\n",
498
+ " eos_token_id=tokenizer.eos_token_id, # End of sequence token\n",
499
+ " pad_token_id=tokenizer.eos_token_id, # Pad token\n",
500
+ " use_cache=True,\n",
501
+ " )\n",
502
+ "\n",
503
+ " input_token_len = input_ids.shape[1]\n",
504
+ " n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()\n",
505
+ " if n_diff_input_output > 0:\n",
506
+ " print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')\n",
507
+ " outputs = self.tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]\n",
508
+ " outputs = outputs.strip()\n",
509
+ " if outputs.endswith(stop_str):\n",
510
+ " outputs = outputs[:-len(stop_str)]\n",
511
+ " outputs = outputs.strip()\n",
512
+ " return outputs"
513
+ ]
514
+ },
515
+ {
516
+ "cell_type": "code",
517
+ "execution_count": 54,
518
+ "id": "4efdbad4-d88a-4477-a3a0-f5591cd0b172",
519
+ "metadata": {},
520
+ "outputs": [
521
+ {
522
+ "name": "stderr",
523
+ "output_type": "stream",
524
+ "text": [
525
+ "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.\n",
526
+ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
527
+ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
528
+ ]
529
+ },
530
+ {
531
+ "name": "stdout",
532
+ "output_type": "stream",
533
+ "text": [
534
+ "load llaVA-Phi MLLM!!!\n"
535
+ ]
536
+ },
537
+ {
538
+ "data": {
539
+ "application/vnd.jupyter.widget-view+json": {
540
+ "model_id": "492c17cf54f34d4d9e4f288fc9e72e79",
541
+ "version_major": 2,
542
+ "version_minor": 0
543
+ },
544
+ "text/plain": [
545
+ "Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
546
+ ]
547
+ },
548
+ "metadata": {},
549
+ "output_type": "display_data"
550
+ },
551
+ {
552
+ "name": "stdout",
553
+ "output_type": "stream",
554
+ "text": [
555
+ "{'device_map': 'cuda'}\n"
556
+ ]
557
+ },
558
+ {
559
+ "name": "stderr",
560
+ "output_type": "stream",
561
+ "text": [
562
+ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
563
+ ]
564
+ }
565
+ ],
566
+ "source": [
567
+ "multimodal_phi2 = MultiModalPhi2()"
568
+ ]
569
+ },
570
+ {
571
+ "cell_type": "code",
572
+ "execution_count": 57,
573
+ "id": "9a6de0b0-a231-4d50-88e8-e40c6f7216c3",
574
+ "metadata": {},
575
+ "outputs": [],
576
+ "source": [
577
+ "text = \"tell me about the audio\""
578
+ ]
579
+ },
580
+ {
581
+ "cell_type": "code",
582
+ "execution_count": 58,
583
+ "id": "b4919948-6a75-4d19-ba95-9ba233a7d3d9",
584
+ "metadata": {},
585
+ "outputs": [
586
+ {
587
+ "data": {
588
+ "text/plain": [
589
+ "'is a popular Japanese drama series featuring a man in a red and white costume, who is dressed as Santa Claus, is walking down the street. The scene takes place in a busy city environment, with people walking and standing on the sidewalk, likely enjoying the festive atmosphere and the festive atmosphere.'"
590
+ ]
591
+ },
592
+ "execution_count": 58,
593
+ "metadata": {},
594
+ "output_type": "execute_result"
595
+ }
596
+ ],
597
+ "source": [
598
+ "multimodal_phi2(text, audio, image)"
599
+ ]
600
+ },
601
+ {
602
+ "cell_type": "code",
603
+ "execution_count": null,
604
+ "id": "590f2d64-62ed-4e6f-b7c8-b0cf68aecaab",
605
+ "metadata": {},
606
+ "outputs": [],
607
+ "source": []
608
+ },
609
+ {
610
+ "cell_type": "code",
611
+ "execution_count": 64,
612
+ "id": "c921eb63-feb5-4fa9-993b-2faeb6dfe1db",
613
+ "metadata": {},
614
+ "outputs": [],
615
+ "source": [
616
+ "from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig, CLIPImageProcessor"
617
+ ]
618
+ },
619
+ {
620
+ "cell_type": "code",
621
+ "execution_count": 65,
622
+ "id": "b470a2c4-806a-435d-9fc2-f17448dbe5fc",
623
+ "metadata": {},
624
+ "outputs": [],
625
+ "source": [
626
+ "from llava_phi.model import LlavaPhiConfig"
627
+ ]
628
+ },
629
+ {
630
+ "cell_type": "code",
631
+ "execution_count": 66,
632
+ "id": "4f7bc91a-0a41-45e5-92a4-daa1e3eea0da",
633
+ "metadata": {},
634
+ "outputs": [
635
+ {
636
+ "name": "stderr",
637
+ "output_type": "stream",
638
+ "text": [
639
+ "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.\n",
640
+ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
641
+ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
642
+ ]
643
+ },
644
+ {
645
+ "data": {
646
+ "application/vnd.jupyter.widget-view+json": {
647
+ "model_id": "993bc3a38cb84de4a2e3a79a3448c4d6",
648
+ "version_major": 2,
649
+ "version_minor": 0
650
+ },
651
+ "text/plain": [
652
+ "Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
653
+ ]
654
+ },
655
+ "metadata": {},
656
+ "output_type": "display_data"
657
+ }
658
+ ],
659
+ "source": [
660
+ "device_map = \"cuda:0\"\n",
661
+ "load_8bit = False\n",
662
+ "load_4bit = False\n",
663
+ "kwargs = {\"device_map\": device_map}\n",
664
+ "if load_8bit:\n",
665
+ " kwargs['load_in_8bit'] = True\n",
666
+ "elif load_4bit:\n",
667
+ " kwargs['load_in_4bit'] = True\n",
668
+ " kwargs['quantization_config'] = BitsAndBytesConfig(\n",
669
+ " load_in_4bit=True,\n",
670
+ " bnb_4bit_compute_dtype=torch.float16,\n",
671
+ " bnb_4bit_use_double_quant=True,\n",
672
+ " bnb_4bit_quant_type='nf4'\n",
673
+ " )\n",
674
+ "config = LlavaPhiConfig.from_pretrained(model_path, trust_remote_code=True)\n",
675
+ "tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)\n",
676
+ "model = LlavaPhiForCausalLM.from_pretrained(\n",
677
+ " model_path, \n",
678
+ " config=config, \n",
679
+ " use_safetensors=True, \n",
680
+ " **kwargs).to(\"cuda\")\n",
681
+ "image_processor = CLIPImageProcessor.from_pretrained(model_path)\n",
682
+ "mm_use_im_start_end = getattr(model.config, \"mm_use_im_start_end\", False)\n",
683
+ "mm_use_im_patch_token = getattr(model.config, \"mm_use_im_patch_token\", True)\n",
684
+ "\n",
685
+ "# TODO: the tokenizer length of phi-2 is 50295, but the output class of lm_head is 51200\n",
686
+ "if mm_use_im_patch_token:\n",
687
+ " tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)\n",
688
+ "if mm_use_im_start_end:\n",
689
+ " tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)\n",
690
+ " \n",
691
+ "if hasattr(model.config, \"max_sequence_length\"):\n",
692
+ " context_len = model.config.max_sequence_length\n",
693
+ "else:\n",
694
+ " context_len = 2048"
695
+ ]
696
+ },
697
+ {
698
+ "cell_type": "code",
699
+ "execution_count": 70,
700
+ "id": "99355837-a297-4a25-aeb3-1670af7e9251",
701
+ "metadata": {},
702
+ "outputs": [
703
+ {
704
+ "ename": "KeyboardInterrupt",
705
+ "evalue": "",
706
+ "output_type": "error",
707
+ "traceback": [
708
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
709
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
710
+ "Cell \u001b[0;32mIn[70], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msave_pretrained\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mLlava-Phi-Checkpoint\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n",
711
+ "File \u001b[0;32m~/miniconda3/envs/torchenv/lib/python3.10/site-packages/transformers/modeling_utils.py:2376\u001b[0m, in \u001b[0;36mPreTrainedModel.save_pretrained\u001b[0;34m(self, save_directory, is_main_process, state_dict, save_function, push_to_hub, max_shard_size, safe_serialization, variant, token, save_peft_format, **kwargs)\u001b[0m\n\u001b[1;32m 2372\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m shard_file, shard \u001b[38;5;129;01min\u001b[39;00m shards\u001b[38;5;241m.\u001b[39mitems():\n\u001b[1;32m 2373\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m safe_serialization:\n\u001b[1;32m 2374\u001b[0m \u001b[38;5;66;03m# At some point we will need to deal better with save_function (used for TPU and other distributed\u001b[39;00m\n\u001b[1;32m 2375\u001b[0m \u001b[38;5;66;03m# joyfulness), but for now this enough.\u001b[39;00m\n\u001b[0;32m-> 2376\u001b[0m \u001b[43msafe_save_file\u001b[49m\u001b[43m(\u001b[49m\u001b[43mshard\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mos\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpath\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mjoin\u001b[49m\u001b[43m(\u001b[49m\u001b[43msave_directory\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mshard_file\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmetadata\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m{\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mformat\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mpt\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m}\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2377\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 2378\u001b[0m save_function(shard, os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mjoin(save_directory, shard_file))\n",
712
+ "File \u001b[0;32m~/miniconda3/envs/torchenv/lib/python3.10/site-packages/safetensors/torch.py:281\u001b[0m, in \u001b[0;36msave_file\u001b[0;34m(tensors, filename, metadata)\u001b[0m\n\u001b[1;32m 250\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21msave_file\u001b[39m(\n\u001b[1;32m 251\u001b[0m tensors: Dict[\u001b[38;5;28mstr\u001b[39m, torch\u001b[38;5;241m.\u001b[39mTensor],\n\u001b[1;32m 252\u001b[0m filename: Union[\u001b[38;5;28mstr\u001b[39m, os\u001b[38;5;241m.\u001b[39mPathLike],\n\u001b[1;32m 253\u001b[0m metadata: Optional[Dict[\u001b[38;5;28mstr\u001b[39m, \u001b[38;5;28mstr\u001b[39m]] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 254\u001b[0m ):\n\u001b[1;32m 255\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 256\u001b[0m \u001b[38;5;124;03m Saves a dictionary of tensors into raw bytes in safetensors format.\u001b[39;00m\n\u001b[1;32m 257\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 279\u001b[0m \u001b[38;5;124;03m ```\u001b[39;00m\n\u001b[1;32m 280\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 281\u001b[0m \u001b[43mserialize_file\u001b[49m\u001b[43m(\u001b[49m\u001b[43m_flatten\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfilename\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmetadata\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmetadata\u001b[49m\u001b[43m)\u001b[49m\n",
713
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
714
+ ]
715
+ }
716
+ ],
717
+ "source": [
718
+ "model.save_pretrained(\"Llava-Phi-Checkpoint\")"
719
+ ]
720
+ },
721
+ {
722
+ "cell_type": "code",
723
+ "execution_count": null,
724
+ "id": "fa0bec34-a148-4340-a30c-6f09dd5e71ca",
725
+ "metadata": {},
726
+ "outputs": [],
727
+ "source": [
728
+ "model.push_to_hub(\"RaviNaik/Llava-Phi2\")"
729
+ ]
730
+ },
731
+ {
732
+ "cell_type": "code",
733
+ "execution_count": 73,
734
+ "id": "382f74b0-2967-408a-badc-a90918810d74",
735
+ "metadata": {},
736
+ "outputs": [
737
+ {
738
+ "data": {
739
+ "text/plain": [
740
+ "CommitInfo(commit_url='https://huggingface.co/RaviNaik/Llava-Phi2/commit/fa8f7240058241243f6bdc3d6ab44bb691f76e39', commit_message='Upload tokenizer', commit_description='', oid='fa8f7240058241243f6bdc3d6ab44bb691f76e39', pr_url=None, pr_revision=None, pr_num=None)"
741
+ ]
742
+ },
743
+ "execution_count": 73,
744
+ "metadata": {},
745
+ "output_type": "execute_result"
746
+ }
747
+ ],
748
+ "source": [
749
+ "tokenizer.push_to_hub(\"RaviNaik/Llava-Phi2\")"
750
+ ]
751
+ },
752
+ {
753
+ "cell_type": "code",
754
+ "execution_count": null,
755
+ "id": "b851459b-d3ac-4fb8-99b6-17a648adc41f",
756
+ "metadata": {},
757
+ "outputs": [],
758
+ "source": []
759
+ }
760
+ ],
761
+ "metadata": {
762
+ "kernelspec": {
763
+ "display_name": "Python 3 (ipykernel)",
764
+ "language": "python",
765
+ "name": "python3"
766
+ },
767
+ "language_info": {
768
+ "codemirror_mode": {
769
+ "name": "ipython",
770
+ "version": 3
771
+ },
772
+ "file_extension": ".py",
773
+ "mimetype": "text/x-python",
774
+ "name": "python",
775
+ "nbconvert_exporter": "python",
776
+ "pygments_lexer": "ipython3",
777
+ "version": "3.10.12"
778
+ }
779
+ },
780
+ "nbformat": 4,
781
+ "nbformat_minor": 5
782
+ }
Experiments/instruct_150k_data.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
Experiments/instruct_data.py ADDED
@@ -0,0 +1,39 @@
1
+ from datasets import Dataset, IterableDataset
2
+ from PIL import Image
3
+
4
+ # ChatML format
5
+ templates = {
6
+ "assistant": "<|im_start|>assistant\n{msg}<|im_end|>", # message by assistant
7
+ "user": "<|im_start|>user\n{msg}<|im_end|>" # message by user
8
+ }
9
+
10
+ ds = Dataset.from_json("llava_instruct_150k.json", split="train")
11
+ ds_stream = ds.to_iterable_dataset()
12
+
13
+
14
+ def get_image(image_path):
15
+ image_path = f"train2014/COCO_train2014_{image_path}"
16
+ img = Image.open(image_path)
17
+ return img
18
+
19
+ def get_chatml_text(conversations):
20
+ chatml_text = ""
21
+ for conversation in conversations:
22
+ role = conversation["from"]
23
+ role = "user" if role == "human" else "assistant"
24
+ content = conversation["value"]
25
+
26
+ formatted_text = templates[role].format(msg=content)
27
+ chatml_text += formatted_text + "\n"
28
+ return chatml_text
29
+
30
+ def instruct_data_generator():
31
+ for sample in ds_stream:
32
+ image_path = sample["image"]
33
+ conversations = sample["conversations"]
34
+
35
+ image = get_image(image_path)
36
+ text = get_chatml_text(conversations)
37
+ yield {"text": text, "image": image}
38
+
39
+ instruct_ds = IterableDataset.from_generator(generator=instruct_data_generator)
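+ 
+ # Hypothetical usage sketch (not part of the original script): stream one sample to
+ # sanity-check the ChatML formatting and image loading.
+ # sample = next(iter(instruct_ds))
+ # print(sample["text"][:300], sample["image"].size)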
Experiments/llava_exp.ipynb ADDED
@@ -0,0 +1,145 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "99576983-f881-47c8-8b5e-c6f561a93e71",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import transformers"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "code",
15
+ "execution_count": 2,
16
+ "id": "58ba19f2-4b91-4f90-a33d-4c1ed17e202a",
17
+ "metadata": {},
18
+ "outputs": [],
19
+ "source": [
20
+ "from transformers import LlavaForConditionalGeneration, LlavaConfig, CLIPVisionConfig, PhiConfig\n",
21
+ "\n",
22
+ "# Initializing a CLIP-vision config\n",
23
+ "vision_config = CLIPVisionConfig()\n",
24
+ "\n",
25
+ "# Initializing a Llama config\n",
26
+ "text_config = PhiConfig()\n",
27
+ "\n",
28
+ "# Initializing a Llava llava-1.5-7b style configuration\n",
29
+ "configuration = LlavaConfig(vision_config, text_config)\n",
30
+ "\n",
31
+ "# Initializing a model from the llava-1.5-7b style configuration\n",
32
+ "model = LlavaForConditionalGeneration(configuration)\n",
33
+ "\n",
34
+ "# Accessing the model configuration\n",
35
+ "configuration = model.config"
36
+ ]
37
+ },
38
+ {
39
+ "cell_type": "code",
40
+ "execution_count": 5,
41
+ "id": "a806a07a-fe72-45a3-8ceb-8e942c6c845d",
42
+ "metadata": {},
43
+ "outputs": [
44
+ {
45
+ "data": {
46
+ "text/plain": [
47
+ "LlavaConfig {\n",
48
+ " \"ignore_index\": -100,\n",
49
+ " \"image_token_index\": 32000,\n",
50
+ " \"model_type\": \"llava\",\n",
51
+ " \"projector_hidden_act\": \"gelu\",\n",
52
+ " \"text_config\": {\n",
53
+ " \"embd_pdrop\": 0.0,\n",
54
+ " \"hidden_act\": \"gelu_new\",\n",
55
+ " \"hidden_size\": 2048,\n",
56
+ " \"intermediate_size\": 8192,\n",
57
+ " \"layer_norm_eps\": 1e-05,\n",
58
+ " \"model_type\": \"phi\",\n",
59
+ " \"num_hidden_layers\": 24,\n",
60
+ " \"partial_rotary_factor\": 0.5,\n",
61
+ " \"qk_layernorm\": false,\n",
62
+ " \"resid_pdrop\": 0.0,\n",
63
+ " \"vocab_size\": 51200\n",
64
+ " },\n",
65
+ " \"transformers_version\": \"4.36.2\",\n",
66
+ " \"vision_config\": {\n",
67
+ " \"hidden_size\": 768,\n",
68
+ " \"image_size\": 224,\n",
69
+ " \"intermediate_size\": 3072,\n",
70
+ " \"model_type\": \"clip_vision_model\",\n",
71
+ " \"num_attention_heads\": 12,\n",
72
+ " \"num_hidden_layers\": 12,\n",
73
+ " \"patch_size\": 32,\n",
74
+ " \"projection_dim\": 512\n",
75
+ " },\n",
76
+ " \"vision_feature_layer\": -2,\n",
77
+ " \"vision_feature_select_strategy\": \"default\",\n",
78
+ " \"vocab_size\": 32000\n",
79
+ "}"
80
+ ]
81
+ },
82
+ "execution_count": 5,
83
+ "metadata": {},
84
+ "output_type": "execute_result"
85
+ }
86
+ ],
87
+ "source": [
88
+ "model.config"
89
+ ]
90
+ },
91
+ {
92
+ "cell_type": "code",
93
+ "execution_count": 6,
94
+ "id": "79efbc6b-f005-4a5c-82a1-112fa37f1904",
95
+ "metadata": {},
96
+ "outputs": [
97
+ {
98
+ "name": "stdout",
99
+ "output_type": "stream",
100
+ "text": [
101
+ "Cloning into 'llava-phi'...\n",
102
+ "remote: Enumerating objects: 151, done.\u001b[K\n",
103
+ "remote: Counting objects: 100% (151/151), done.\u001b[K\n",
104
+ "remote: Compressing objects: 100% (116/116), done.\u001b[K\n",
105
+ "remote: Total 151 (delta 36), reused 133 (delta 25), pack-reused 0\u001b[K\n",
106
+ "Receiving objects: 100% (151/151), 333.89 KiB | 112.00 KiB/s, done.\n",
107
+ "Resolving deltas: 100% (36/36), done.\n"
108
+ ]
109
+ }
110
+ ],
111
+ "source": [
112
+ "!git clone https://github.com/zhuyiche/llava-phi.git"
113
+ ]
114
+ },
115
+ {
116
+ "cell_type": "code",
117
+ "execution_count": null,
118
+ "id": "cf827184-f334-4d86-ace1-fe9c92f84d66",
119
+ "metadata": {},
120
+ "outputs": [],
121
+ "source": []
122
+ }
123
+ ],
124
+ "metadata": {
125
+ "kernelspec": {
126
+ "display_name": "Python 3 (ipykernel)",
127
+ "language": "python",
128
+ "name": "python3"
129
+ },
130
+ "language_info": {
131
+ "codemirror_mode": {
132
+ "name": "ipython",
133
+ "version": 3
134
+ },
135
+ "file_extension": ".py",
136
+ "mimetype": "text/x-python",
137
+ "name": "python",
138
+ "nbconvert_exporter": "python",
139
+ "pygments_lexer": "ipython3",
140
+ "version": "3.10.12"
141
+ }
142
+ },
143
+ "nbformat": 4,
144
+ "nbformat_minor": 5
145
+ }
Experiments/multimodal_exp.ipynb ADDED
@@ -0,0 +1,362 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 23,
6
+ "id": "d4bed9ef-4bff-4d61-a4f9-a585f377f136",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "from PIL import Image\n",
11
+ "import requests\n",
12
+ "\n",
13
+ "import torch\n",
14
+ "from torch import nn\n",
15
+ "from transformers import AutoProcessor, CLIPVisionModel, CLIPVisionConfig, CLIPPreTrainedModel\n",
16
+ "from transformers.models.clip.modeling_clip import CLIPVisionModelOutput, CLIPVisionTransformer\n",
17
+ "from transformers import WhisperProcessor, WhisperForConditionalGeneration\n",
18
+ "from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoTokenizer\n",
19
+ "from typing import Optional, Union, Tuple"
20
+ ]
21
+ },
22
+ {
23
+ "cell_type": "code",
24
+ "execution_count": 43,
25
+ "id": "952314f0-ee9d-45e7-85b8-1e3e44c1a2fd",
26
+ "metadata": {},
27
+ "outputs": [],
28
+ "source": [
29
+ "class VisionLanguageConnector(nn.Module):\n",
30
+ " def __init__(self, hidden_size, projection_dim):\n",
31
+ " super().__init__()\n",
32
+ " self.mlp = nn.Sequential(\n",
33
+ " nn.Linear(hidden_size, hidden_size, bias=False),\n",
34
+ " nn.GELU(),\n",
35
+ " nn.Linear(hidden_size, projection_dim, bias=False)\n",
36
+ " )\n",
37
+ "\n",
38
+ " def forward(self, x):\n",
39
+ " return self.mlp(x)\n",
40
+ " \n",
41
+ "class ClipWithProjection():\n",
42
+ " config_class = CLIPVisionConfig\n",
43
+ " main_input_name = \"pixel_values\"\n",
44
+ "\n",
45
+ " def __init__(self, hidden_size, projection_dim):\n",
46
+ " super().__init__()\n",
47
+ " \n",
48
+ " self.processor = AutoProcessor.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
49
+ " self.vision_model = CLIPVisionModel.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
50
+ " self.vision_language_connector = VisionLanguageConnector(hidden_size, projection_dim)\n",
51
+ "\n",
52
+ " def forward(\n",
53
+ " self,\n",
54
+ " image = None,\n",
55
+ " output_attentions: Optional[bool] = None,\n",
56
+ " output_hidden_states: Optional[bool] = None,\n",
57
+ " return_dict: Optional[bool] = None,\n",
58
+ " ) -> Union[Tuple, CLIPVisionModelOutput]:\n",
59
+ " \n",
60
+ " pixel_values = self.processor(images=image, return_tensors=\"pt\")[\"pixel_values\"]\n",
61
+ " vision_outputs = self.vision_model(\n",
62
+ " pixel_values=pixel_values,\n",
63
+ " output_attentions=output_attentions,\n",
64
+ " output_hidden_states=output_hidden_states,\n",
65
+ " return_dict=return_dict,\n",
66
+ " )\n",
67
+ "\n",
68
+ " pooled_output = vision_outputs[1] # pooled_output\n",
69
+ "\n",
70
+ " image_embeds = self.vision_language_connector(pooled_output)\n",
71
+ "\n",
72
+ " return CLIPVisionModelOutput(\n",
73
+ " image_embeds=image_embeds,\n",
74
+ " last_hidden_state=vision_outputs.last_hidden_state,\n",
75
+ " hidden_states=vision_outputs.hidden_states,\n",
76
+ " attentions=vision_outputs.attentions,\n",
77
+ " )"
78
+ ]
79
+ },
80
+ {
81
+ "cell_type": "code",
82
+ "execution_count": 44,
83
+ "id": "bd2889fe-be85-44a3-afe8-65b47f7a93c3",
84
+ "metadata": {},
85
+ "outputs": [],
86
+ "source": [
87
+ "url = \"http://images.cocodataset.org/val2017/000000039769.jpg\"\n",
88
+ "image = Image.open(requests.get(url, stream=True).raw)"
89
+ ]
90
+ },
91
+ {
92
+ "cell_type": "code",
93
+ "execution_count": 46,
94
+ "id": "17c72699-fe98-4b96-b63c-5c8ab7c1a65f",
95
+ "metadata": {},
96
+ "outputs": [],
97
+ "source": [
98
+ "# model = ClipWithProjection(768, 512)\n",
99
+ "# model.forward(image)"
100
+ ]
101
+ },
102
+ {
103
+ "cell_type": "code",
104
+ "execution_count": 47,
105
+ "id": "70806156-38a9-45a2-bf9f-e72047a0173f",
106
+ "metadata": {},
107
+ "outputs": [],
108
+ "source": [
109
+ "class AudioLanguageConnector:\n",
110
+ " def __init__(self, projection_dim):\n",
111
+ " model_name = \"microsoft/phi-2\"\n",
112
+ " self.phi2_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n",
113
+ " self.phi2_tokenizer.pad_token = self.phi2_tokenizer.eos_token\n",
114
+ " self.phi2_tokenizer.max_length = projection_dim\n",
115
+ "\n",
116
+ " def __call__(self, text):\n",
117
+ " text = f\"<audio_start> {text} <audio_end>\"\n",
118
+ " tokens = self.phi2_tokenizer(text, return_tensors=\"pt\", return_attention_mask=False)\n",
119
+ " return tokens\n",
120
+ " \n",
121
+ "\n",
122
+ "class WhisperWithProjection:\n",
123
+ " def __init__(self, projection_dim):\n",
124
+ " self.processor = WhisperProcessor.from_pretrained(\"openai/whisper-tiny\")\n",
125
+ " self.model = WhisperForConditionalGeneration.from_pretrained(\"openai/whisper-tiny\")\n",
126
+ " self.model.config.forced_decoder_ids = None\n",
127
+ " self.audio_language_connector = AudioLanguageConnector(projection_dim)\n",
128
+ " \n",
129
+ " def forward(self, audio):\n",
130
+ " input_features = self.processor(audio[\"array\"],\n",
131
+ " sampling_rate=audio[\"sampling_rate\"],\n",
132
+ " return_tensors=\"pt\").input_features\n",
133
+ " # generate token ids\n",
134
+ " predicted_ids = self.model.generate(input_features)\n",
135
+ " # decode token ids to text \n",
136
+ " transcription = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)\n",
137
+ "\n",
138
+ " audio_embeddings = self.audio_language_connector(transcription)\n",
139
+ " return audio_embeddings"
140
+ ]
141
+ },
142
+ {
143
+ "cell_type": "code",
144
+ "execution_count": 48,
145
+ "id": "79cc4d98-498b-4042-bd71-143b2477733d",
146
+ "metadata": {},
147
+ "outputs": [],
148
+ "source": [
149
+ "class TextModality:\n",
150
+ " def __init__(self, projection_dim):\n",
151
+ " model_name = \"microsoft/phi-2\"\n",
152
+ " self.phi2_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n",
153
+ " self.phi2_tokenizer.pad_token = self.phi2_tokenizer.eos_token\n",
154
+ " self.phi2_tokenizer.max_length = projection_dim\n",
155
+ "\n",
156
+ "\n",
157
+ " def __call__(self, text):\n",
158
+ " tokens = self.phi2_tokenizer(text, return_tensors=\"pt\", return_attention_mask=False)\n",
159
+ " return tokens"
160
+ ]
161
+ },
162
+ {
163
+ "cell_type": "code",
164
+ "execution_count": 77,
165
+ "id": "ba4c4772-923f-48e8-a4af-b7d9c192dd4b",
166
+ "metadata": {},
167
+ "outputs": [],
168
+ "source": [
169
+ "class MultiModalPhi2:\n",
170
+ " def __init__(self):\n",
171
+ " self.text_modality = TextModality(projection_dim=768)\n",
172
+ " self.whisper_w_proj = WhisperWithProjection(projection_dim=512)\n",
173
+ " self.clip_w_proj = ClipWithProjection(hidden_size=768, projection_dim=768)\n",
174
+ " self.llm = self.load_llm()\n",
175
+ "\n",
176
+ " def load_llm(self):\n",
177
+ " model_name = \"microsoft/phi-2\"\n",
178
+ " \n",
179
+ " bnb_config = BitsAndBytesConfig(\n",
180
+ " load_in_4bit=True,\n",
181
+ " bnb_4bit_quant_type=\"nf4\",\n",
182
+ " bnb_4bit_compute_dtype=torch.float16)\n",
183
+ " \n",
184
+ " model = AutoModelForCausalLM.from_pretrained(\n",
185
+ " model_name,\n",
186
+ " quantization_config=bnb_config,\n",
187
+ " trust_remote_code=True,\n",
188
+ " device_map=\"cuda:0\"\n",
189
+ " )\n",
190
+ " model.config.use_cache = False\n",
191
+ " return model\n",
192
+ "\n",
193
+ " def forward(self, audio, image, text):\n",
194
+ " if text is not None:\n",
195
+ " text_embed = self.text_modality(text)[\"input_ids\"]\n",
196
+ " if audio is not None:\n",
197
+ " audio_embed = self.whisper_w_proj.forward(audio)[\"input_ids\"]\n",
198
+ " if image is not None:\n",
199
+ " image_embed = self.clip_w_proj.forward(image)[0]\n",
200
+ " print(text_embed.shape, text_embed.dtype)\n",
201
+ " print(audio_embed.shape, audio_embed.dtype)\n",
202
+ " print(image_embed.shape, image_embed.dtype)\n",
203
+ " \n",
204
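+ "        # NOTE: text_embed and audio_embed are integer token ids while image_embed is a float projection,\n",
+ "        # so passing the concatenated tensor as input_ids triggers the embedding dtype error recorded below\n",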
+ " inputs = torch.concat([text_embed, audio_embed, image_embed], dim=1)\n",
205
+ " print(inputs.shape, inputs.dtype)\n",
206
+ " outputs = self.llm(inputs)\n",
207
+ "\n",
208
+ " return outputs \n",
209
+ " \n",
210
+ "\n",
211
+ " def generate(self, audio, text):\n",
212
+ " text_embeddings = self.text_modality(text)\n",
213
+ " audio_embeddings = self.whisper_w_proj.forward(audio)\n",
214
+ " inputs = torch.concat([text_embed[\"input_ids\"], audio_embed[\"input_ids\"]], dim=1)\n",
215
+ " \n",
216
+ " outputs = self.llm.generate(inputs, max_length=200)\n",
217
+ " text = self.text_modality.phi2_tokenizer.batch_decode(outputs)[0]\n",
218
+ " print(text)"
219
+ ]
220
+ },
221
+ {
222
+ "cell_type": "code",
223
+ "execution_count": 74,
224
+ "id": "7ca694eb-8009-4eb9-9a4c-eac406ab9584",
225
+ "metadata": {},
226
+ "outputs": [],
227
+ "source": [
228
+ "from datasets import load_dataset\n",
229
+ "audio_ds = load_dataset(\"hf-internal-testing/librispeech_asr_dummy\", \"clean\", split=\"validation\")\n",
230
+ "audio = audio_ds[0][\"audio\"]"
231
+ ]
232
+ },
233
+ {
234
+ "cell_type": "code",
235
+ "execution_count": 58,
236
+ "id": "37be28c5-4cc3-4471-b394-032c7602accc",
237
+ "metadata": {},
238
+ "outputs": [],
239
+ "source": [
240
+ "text = \"explain about the audio\""
241
+ ]
242
+ },
243
+ {
244
+ "cell_type": "code",
245
+ "execution_count": 59,
246
+ "id": "c0705114-1670-4937-bc3e-3660e5a5d2c5",
247
+ "metadata": {},
248
+ "outputs": [],
249
+ "source": [
250
+ "# image"
251
+ ]
252
+ },
253
+ {
254
+ "cell_type": "code",
255
+ "execution_count": 78,
256
+ "id": "0d7e5b49-b4bd-477c-87b8-91ef70857677",
257
+ "metadata": {},
258
+ "outputs": [
259
+ {
260
+ "name": "stderr",
261
+ "output_type": "stream",
262
+ "text": [
263
+ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
264
+ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
265
+ ]
266
+ },
267
+ {
268
+ "data": {
269
+ "application/vnd.jupyter.widget-view+json": {
270
+ "model_id": "733dc7b2208b4853a89aea49bff9a55c",
271
+ "version_major": 2,
272
+ "version_minor": 0
273
+ },
274
+ "text/plain": [
275
+ "Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
276
+ ]
277
+ },
278
+ "metadata": {},
279
+ "output_type": "display_data"
280
+ }
281
+ ],
282
+ "source": [
283
+ "model = MultiModalPhi2()"
284
+ ]
285
+ },
286
+ {
287
+ "cell_type": "code",
288
+ "execution_count": 79,
289
+ "id": "0b6471c4-4553-47f3-b38f-46057dcf80f2",
290
+ "metadata": {},
291
+ "outputs": [
292
+ {
293
+ "name": "stdout",
294
+ "output_type": "stream",
295
+ "text": [
296
+ "torch.Size([1, 5]) torch.int64\n",
297
+ "torch.Size([1, 33]) torch.int64\n",
298
+ "torch.Size([1, 768]) torch.float32\n",
299
+ "torch.Size([1, 806]) torch.float32\n"
300
+ ]
301
+ },
302
+ {
303
+ "ename": "RuntimeError",
304
+ "evalue": "Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.cuda.FloatTensor instead (while checking arguments for embedding)",
305
+ "output_type": "error",
306
+ "traceback": [
307
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
308
+ "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
309
+ "Cell \u001b[0;32mIn[79], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mforward\u001b[49m\u001b[43m(\u001b[49m\u001b[43maudio\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mimage\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtext\u001b[49m\u001b[43m)\u001b[49m\n",
310
+ "Cell \u001b[0;32mIn[77], line 38\u001b[0m, in \u001b[0;36mMultiModalPhi2.forward\u001b[0;34m(self, audio, image, text)\u001b[0m\n\u001b[1;32m 36\u001b[0m inputs \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mconcat([text_embed, audio_embed, image_embed], dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m 37\u001b[0m \u001b[38;5;28mprint\u001b[39m(inputs\u001b[38;5;241m.\u001b[39mshape, inputs\u001b[38;5;241m.\u001b[39mdtype)\n\u001b[0;32m---> 38\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mllm\u001b[49m\u001b[43m(\u001b[49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 40\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m outputs\n",
311
+ "File \u001b[0;32m~/miniconda3/envs/torchenv/lib/python3.10/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
312
+ "File \u001b[0;32m~/miniconda3/envs/torchenv/lib/python3.10/site-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
313
+ "File \u001b[0;32m~/miniconda3/envs/torchenv/lib/python3.10/site-packages/accelerate/hooks.py:165\u001b[0m, in \u001b[0;36madd_hook_to_module.<locals>.new_forward\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 163\u001b[0m output \u001b[38;5;241m=\u001b[39m old_forward(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 164\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 165\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mold_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 166\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m module\u001b[38;5;241m.\u001b[39m_hf_hook\u001b[38;5;241m.\u001b[39mpost_forward(module, output)\n",
314
+ "File \u001b[0;32m~/.cache/huggingface/modules/transformers_modules/microsoft/phi-2/85d00b03fee509307549d823fdd095473ba5197c/modeling_phi.py:1049\u001b[0m, in \u001b[0;36mPhiForCausalLM.forward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 1046\u001b[0m return_dict \u001b[38;5;241m=\u001b[39m return_dict \u001b[38;5;28;01mif\u001b[39;00m return_dict \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39muse_return_dict\n\u001b[1;32m 1048\u001b[0m \u001b[38;5;66;03m# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\u001b[39;00m\n\u001b[0;32m-> 1049\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1050\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1051\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1052\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1053\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_values\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1054\u001b[0m \u001b[43m \u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs_embeds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1055\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1056\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1057\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1058\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1059\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1061\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 1062\u001b[0m logits \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlm_head(hidden_states)\n",
315
+ "File \u001b[0;32m~/miniconda3/envs/torchenv/lib/python3.10/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
316
+ "File \u001b[0;32m~/miniconda3/envs/torchenv/lib/python3.10/site-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
317
+ "File \u001b[0;32m~/miniconda3/envs/torchenv/lib/python3.10/site-packages/accelerate/hooks.py:165\u001b[0m, in \u001b[0;36madd_hook_to_module.<locals>.new_forward\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 163\u001b[0m output \u001b[38;5;241m=\u001b[39m old_forward(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 164\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 165\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mold_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 166\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m module\u001b[38;5;241m.\u001b[39m_hf_hook\u001b[38;5;241m.\u001b[39mpost_forward(module, output)\n",
318
+ "File \u001b[0;32m~/.cache/huggingface/modules/transformers_modules/microsoft/phi-2/85d00b03fee509307549d823fdd095473ba5197c/modeling_phi.py:893\u001b[0m, in \u001b[0;36mPhiModel.forward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 890\u001b[0m position_ids \u001b[38;5;241m=\u001b[39m position_ids\u001b[38;5;241m.\u001b[39munsqueeze(\u001b[38;5;241m0\u001b[39m)\n\u001b[1;32m 892\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m inputs_embeds \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 893\u001b[0m inputs_embeds \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43membed_tokens\u001b[49m\u001b[43m(\u001b[49m\u001b[43minput_ids\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 895\u001b[0m inputs_embeds \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39membed_dropout(inputs_embeds)\n\u001b[1;32m 897\u001b[0m \u001b[38;5;66;03m# Attention mask.\u001b[39;00m\n",
319
+ "File \u001b[0;32m~/miniconda3/envs/torchenv/lib/python3.10/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
320
+ "File \u001b[0;32m~/miniconda3/envs/torchenv/lib/python3.10/site-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
321
+ "File \u001b[0;32m~/miniconda3/envs/torchenv/lib/python3.10/site-packages/accelerate/hooks.py:165\u001b[0m, in \u001b[0;36madd_hook_to_module.<locals>.new_forward\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 163\u001b[0m output \u001b[38;5;241m=\u001b[39m old_forward(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 164\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 165\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mold_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 166\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m module\u001b[38;5;241m.\u001b[39m_hf_hook\u001b[38;5;241m.\u001b[39mpost_forward(module, output)\n",
322
+ "File \u001b[0;32m~/miniconda3/envs/torchenv/lib/python3.10/site-packages/torch/nn/modules/sparse.py:162\u001b[0m, in \u001b[0;36mEmbedding.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 161\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Tensor) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[0;32m--> 162\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43membedding\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 163\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpadding_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmax_norm\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 164\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnorm_type\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mscale_grad_by_freq\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msparse\u001b[49m\u001b[43m)\u001b[49m\n",
323
+ "File \u001b[0;32m~/miniconda3/envs/torchenv/lib/python3.10/site-packages/torch/nn/functional.py:2233\u001b[0m, in \u001b[0;36membedding\u001b[0;34m(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)\u001b[0m\n\u001b[1;32m 2227\u001b[0m \u001b[38;5;66;03m# Note [embedding_renorm set_grad_enabled]\u001b[39;00m\n\u001b[1;32m 2228\u001b[0m \u001b[38;5;66;03m# XXX: equivalent to\u001b[39;00m\n\u001b[1;32m 2229\u001b[0m \u001b[38;5;66;03m# with torch.no_grad():\u001b[39;00m\n\u001b[1;32m 2230\u001b[0m \u001b[38;5;66;03m# torch.embedding_renorm_\u001b[39;00m\n\u001b[1;32m 2231\u001b[0m \u001b[38;5;66;03m# remove once script supports set_grad_enabled\u001b[39;00m\n\u001b[1;32m 2232\u001b[0m _no_grad_embedding_renorm_(weight, \u001b[38;5;28minput\u001b[39m, max_norm, norm_type)\n\u001b[0;32m-> 2233\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43membedding\u001b[49m\u001b[43m(\u001b[49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpadding_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mscale_grad_by_freq\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msparse\u001b[49m\u001b[43m)\u001b[49m\n",
324
+ "\u001b[0;31mRuntimeError\u001b[0m: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.cuda.FloatTensor instead (while checking arguments for embedding)"
325
+ ]
326
+ }
327
+ ],
328
+ "source": [
329
+ "model.forward(audio, image, text)"
330
+ ]
331
+ },
332
+ {
333
+ "cell_type": "code",
334
+ "execution_count": null,
335
+ "id": "4ca96caf-82e2-4f07-87b3-8654dfdc89aa",
336
+ "metadata": {},
337
+ "outputs": [],
338
+ "source": []
339
+ }
340
+ ],
341
+ "metadata": {
342
+ "kernelspec": {
343
+ "display_name": "Python 3 (ipykernel)",
344
+ "language": "python",
345
+ "name": "python3"
346
+ },
347
+ "language_info": {
348
+ "codemirror_mode": {
349
+ "name": "ipython",
350
+ "version": 3
351
+ },
352
+ "file_extension": ".py",
353
+ "mimetype": "text/x-python",
354
+ "name": "python",
355
+ "nbconvert_exporter": "python",
356
+ "pygments_lexer": "ipython3",
357
+ "version": "3.10.12"
358
+ }
359
+ },
360
+ "nbformat": 4,
361
+ "nbformat_minor": 5
362
+ }
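The traceback above bottoms out in `torch.embedding` with `RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.cuda.FloatTensor instead`, raised from `PhiModel.forward` when it calls `self.embed_tokens(input_ids)`. In other words, `model.forward(audio, image, text)` ended up feeding float features into the token-embedding table, which only accepts integer token ids; float inputs (such as projected audio or image features) have to enter through `inputs_embeds`, which `PhiForCausalLM.forward` exposes per the signature shown in the trace. A minimal sketch of the distinction, using a small stand-in `nn.Embedding` rather than the real Phi-2 weights (sizes and the commented call are illustrative only):

```python
import torch
from torch import nn

# Stand-in for Phi-2's embed_tokens (the real table is ~51200 x 2560).
embed_tokens = nn.Embedding(num_embeddings=100, embedding_dim=8)

token_ids = torch.tensor([[17, 3, 42]])        # integer ids (torch.int64) -> fine
float_features = torch.randn(1, 3)             # float "features" -> same RuntimeError as above

embed_tokens(token_ids)                        # works
try:
    embed_tokens(float_features)               # fails: embedding indices must be Long/Int
except RuntimeError as err:
    print(err)

# Float multimodal features should skip the embedding table entirely and be
# passed as inputs_embeds (hypothetical call, shown for shape only):
# fused = torch.cat([audio_embeds, image_embeds, embed_tokens(token_ids)], dim=1)
# outputs = model(inputs_embeds=fused)
```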
Experiments/pretrain_data_check.ipynb ADDED
@@ -0,0 +1,304 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 5,
6
+ "id": "61c272f2-edbe-4b7d-8fec-3ab431400cd3",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import json"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "code",
15
+ "execution_count": 2,
16
+ "id": "e9dfd7d7-1685-4fc7-bbb9-3905c32d8ba1",
17
+ "metadata": {},
18
+ "outputs": [],
19
+ "source": [
20
+ "with open(\"metadata.json\", \"rb\") as f:\n",
21
+ " metadata = json.load(f)"
22
+ ]
23
+ },
24
+ {
25
+ "cell_type": "code",
26
+ "execution_count": 4,
27
+ "id": "70bdba48-db01-42ac-8d89-edc69d7d7672",
28
+ "metadata": {},
29
+ "outputs": [
30
+ {
31
+ "data": {
32
+ "text/plain": [
33
+ "595375"
34
+ ]
35
+ },
36
+ "execution_count": 4,
37
+ "metadata": {},
38
+ "output_type": "execute_result"
39
+ }
40
+ ],
41
+ "source": [
42
+ "len(metadata)"
43
+ ]
44
+ },
45
+ {
46
+ "cell_type": "code",
47
+ "execution_count": 14,
48
+ "id": "59e193cc-0dd8-4f7e-959a-fbad0133d76c",
49
+ "metadata": {},
50
+ "outputs": [],
51
+ "source": [
52
+ "with open(\"blip_laion_cc_sbu_558k.jsonblip_laion_cc_sbu_558k.json\", \"rb\") as f:\n",
53
+ " data = json.load(f)"
54
+ ]
55
+ },
56
+ {
57
+ "cell_type": "code",
58
+ "execution_count": 7,
59
+ "id": "f3157f41-269b-4f7a-b3ba-9be711babe02",
60
+ "metadata": {},
61
+ "outputs": [
62
+ {
63
+ "data": {
64
+ "text/plain": [
65
+ "{'id': '004539375',\n",
66
+ " 'image': '00453/004539375.jpg',\n",
67
+ " 'conversations': [{'from': 'human',\n",
68
+ " 'value': 'Render a clear and concise summary of the photo.\\n<image>'},\n",
69
+ " {'from': 'gpt',\n",
70
+ " 'value': 'select luxury furniture 3 - inch gel memory foam mattress topper'}]}"
71
+ ]
72
+ },
73
+ "execution_count": 7,
74
+ "metadata": {},
75
+ "output_type": "execute_result"
76
+ }
77
+ ],
78
+ "source": [
79
+ "data[0]"
80
+ ]
81
+ },
82
+ {
83
+ "cell_type": "code",
84
+ "execution_count": 8,
85
+ "id": "50d8a051-1526-47dd-ad71-d3c66f7bd34e",
86
+ "metadata": {},
87
+ "outputs": [
88
+ {
89
+ "data": {
90
+ "text/plain": [
91
+ "{'id': '004374662',\n",
92
+ " 'image': '00437/004374662.jpg',\n",
93
+ " 'conversations': [{'from': 'human',\n",
94
+ " 'value': 'Give a brief description of the image.\\n<image>'},\n",
95
+ " {'from': 'gpt', 'value': 'the north face duffel bag camo large'}]}"
96
+ ]
97
+ },
98
+ "execution_count": 8,
99
+ "metadata": {},
100
+ "output_type": "execute_result"
101
+ }
102
+ ],
103
+ "source": [
104
+ "data[234]"
105
+ ]
106
+ },
107
+ {
108
+ "cell_type": "code",
109
+ "execution_count": 17,
110
+ "id": "2e6d5664-4583-49a6-93cc-079ee2d1ff6c",
111
+ "metadata": {},
112
+ "outputs": [
113
+ {
114
+ "data": {
115
+ "text/plain": [
116
+ "558128"
117
+ ]
118
+ },
119
+ "execution_count": 17,
120
+ "metadata": {},
121
+ "output_type": "execute_result"
122
+ }
123
+ ],
124
+ "source": [
125
+ "len(data)"
126
+ ]
127
+ },
128
+ {
129
+ "cell_type": "code",
130
+ "execution_count": 10,
131
+ "id": "11ed106d-6bef-482c-a456-5eaaf2025534",
132
+ "metadata": {},
133
+ "outputs": [
134
+ {
135
+ "data": {
136
+ "text/plain": [
137
+ "{'id': 'GCC_train_001749371',\n",
138
+ " 'image': 'GCC_train_001749371.jpg',\n",
139
+ " 'caption': 'if you are dreaming of simpler or off - the - grid living , a yurt is a fantastic option',\n",
140
+ " 'blip_caption': 'a white and tan yurt sitting on a dirt road',\n",
141
+ " 'url': 'https://i.pinimg.com/736x/14/7b/64/147b64467ee966d9a578097bb70475ad--yurt-kits-small-space-living.jpg'}"
142
+ ]
143
+ },
144
+ "execution_count": 10,
145
+ "metadata": {},
146
+ "output_type": "execute_result"
147
+ }
148
+ ],
149
+ "source": [
150
+ "metadata[67]"
151
+ ]
152
+ },
153
+ {
154
+ "cell_type": "code",
155
+ "execution_count": 15,
156
+ "id": "ce8adcec-2499-4be3-be1d-7313fe54e96a",
157
+ "metadata": {},
158
+ "outputs": [
159
+ {
160
+ "data": {
161
+ "text/plain": [
162
+ "{'id': '000466761',\n",
163
+ " 'image': '00046/000466761.jpg',\n",
164
+ " 'conversations': [{'from': 'human',\n",
165
+ " 'value': '<image>\\nProvide a brief description of the given image.'},\n",
166
+ " {'from': 'gpt',\n",
167
+ " 'value': 'a clipboard and a pen with the words public health emergency next to it on a white table'}]}"
168
+ ]
169
+ },
170
+ "execution_count": 15,
171
+ "metadata": {},
172
+ "output_type": "execute_result"
173
+ }
174
+ ],
175
+ "source": [
176
+ "data[67]"
177
+ ]
178
+ },
179
+ {
180
+ "cell_type": "code",
181
+ "execution_count": 16,
182
+ "id": "068313b6-6379-4ca2-892c-682634d3581e",
183
+ "metadata": {},
184
+ "outputs": [
185
+ {
186
+ "data": {
187
+ "text/plain": [
188
+ "list"
189
+ ]
190
+ },
191
+ "execution_count": 16,
192
+ "metadata": {},
193
+ "output_type": "execute_result"
194
+ }
195
+ ],
196
+ "source": [
197
+ "type(data)"
198
+ ]
199
+ },
200
+ {
201
+ "cell_type": "code",
202
+ "execution_count": 24,
203
+ "id": "9ec33b51-4a0b-4a1e-81f7-2fda7cddb25f",
204
+ "metadata": {},
205
+ "outputs": [],
206
+ "source": [
207
+ "sample_data = data[:200000]"
208
+ ]
209
+ },
210
+ {
211
+ "cell_type": "code",
212
+ "execution_count": 25,
213
+ "id": "095685e5-40f1-4d84-8280-ef74fa56c5a2",
214
+ "metadata": {},
215
+ "outputs": [
216
+ {
217
+ "data": {
218
+ "text/plain": [
219
+ "200000"
220
+ ]
221
+ },
222
+ "execution_count": 25,
223
+ "metadata": {},
224
+ "output_type": "execute_result"
225
+ }
226
+ ],
227
+ "source": [
228
+ "len(sample_data)"
229
+ ]
230
+ },
231
+ {
232
+ "cell_type": "code",
233
+ "execution_count": 26,
234
+ "id": "ffbad552-23fd-475f-8e9a-7118bcc4f51e",
235
+ "metadata": {},
236
+ "outputs": [],
237
+ "source": [
238
+ "with open(\"llava-phi/pretrain_data/blip_sample.json\", \"w\") as f:\n",
239
+ " json.dump(sample_data, f)"
240
+ ]
241
+ },
242
+ {
243
+ "cell_type": "code",
244
+ "execution_count": 27,
245
+ "id": "69a05d25-6f3b-40c0-a3b5-e185ff526471",
246
+ "metadata": {},
247
+ "outputs": [],
248
+ "source": [
249
+ "with open(\"llava-phi/pretrain_data/blip_sample.json\", \"rb\") as f:\n",
250
+ " sample = json.load(f)"
251
+ ]
252
+ },
253
+ {
254
+ "cell_type": "code",
255
+ "execution_count": 28,
256
+ "id": "200eea06-dfd6-4b3a-bb91-82af7d363951",
257
+ "metadata": {},
258
+ "outputs": [
259
+ {
260
+ "data": {
261
+ "text/plain": [
262
+ "200000"
263
+ ]
264
+ },
265
+ "execution_count": 28,
266
+ "metadata": {},
267
+ "output_type": "execute_result"
268
+ }
269
+ ],
270
+ "source": [
271
+ "len(sample)"
272
+ ]
273
+ },
274
+ {
275
+ "cell_type": "code",
276
+ "execution_count": null,
277
+ "id": "f86caa1e-edea-4a9c-934f-5420ede80d0d",
278
+ "metadata": {},
279
+ "outputs": [],
280
+ "source": []
281
+ }
282
+ ],
283
+ "metadata": {
284
+ "kernelspec": {
285
+ "display_name": "Python 3 (ipykernel)",
286
+ "language": "python",
287
+ "name": "python3"
288
+ },
289
+ "language_info": {
290
+ "codemirror_mode": {
291
+ "name": "ipython",
292
+ "version": 3
293
+ },
294
+ "file_extension": ".py",
295
+ "mimetype": "text/x-python",
296
+ "name": "python",
297
+ "nbconvert_exporter": "python",
298
+ "pygments_lexer": "ipython3",
299
+ "version": "3.10.12"
300
+ }
301
+ },
302
+ "nbformat": 4,
303
+ "nbformat_minor": 5
304
+ }
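The notebook above loads `blip_laion_cc_sbu_558k.json` (558,128 records), keeps the first 200,000, and writes them to `llava-phi/pretrain_data/blip_sample.json`. The same subsetting step as a compact, standalone sketch; the seeded random variant in the comments is an optional alternative to slicing (it avoids any ordering bias in the source file), not what the notebook actually does:

```python
import json

SRC = "blip_laion_cc_sbu_558k.json"
DST = "llava-phi/pretrain_data/blip_sample.json"
N = 200_000

with open(SRC, "r") as f:
    data = json.load(f)                 # list of {'id', 'image', 'conversations'} records

sample_data = data[:N]                  # notebook behaviour: first N records

# Optional alternative: a reproducible random subset instead of the head of the file.
# import random
# random.seed(42)
# sample_data = random.sample(data, N)

with open(DST, "w") as f:
    json.dump(sample_data, f)

print(len(sample_data))                 # 200000
```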
Experiments/whispher_exp.ipynb ADDED
@@ -0,0 +1,500 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 9,
6
+ "id": "bb4dd66b-0c17-48d4-9d34-f48cece2feb5",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "# !pip install soundfile\n",
11
+ "# !pip install librosa"
12
+ ]
13
+ },
14
+ {
15
+ "cell_type": "code",
16
+ "execution_count": 1,
17
+ "id": "6e9386ea-4862-4f5b-a02f-d656e1a5ab9e",
18
+ "metadata": {},
19
+ "outputs": [],
20
+ "source": [
21
+ "from transformers import WhisperProcessor, WhisperForConditionalGeneration\n",
22
+ "from datasets import load_dataset"
23
+ ]
24
+ },
25
+ {
26
+ "cell_type": "code",
27
+ "execution_count": 2,
28
+ "id": "914ab2b4-389d-4c48-8d1d-1250356646ac",
29
+ "metadata": {},
30
+ "outputs": [],
31
+ "source": [
32
+ "# load model and processor\n",
33
+ "processor = WhisperProcessor.from_pretrained(\"openai/whisper-tiny\")\n",
34
+ "model = WhisperForConditionalGeneration.from_pretrained(\"openai/whisper-tiny\")\n",
35
+ "model.config.forced_decoder_ids = None\n",
36
+ "\n",
37
+ "# load dummy dataset and read audio files\n",
38
+ "ds = load_dataset(\"hf-internal-testing/librispeech_asr_dummy\", \"clean\", split=\"validation\")\n",
39
+ "sample = ds[0][\"audio\"]"
40
+ ]
41
+ },
42
+ {
43
+ "cell_type": "code",
44
+ "execution_count": 3,
45
+ "id": "2b299bab-1228-48d9-a8a5-3d5b6c52162d",
46
+ "metadata": {},
47
+ "outputs": [
48
+ {
49
+ "data": {
50
+ "text/plain": [
51
+ "{'path': '/home/ravi.naik/.cache/huggingface/datasets/downloads/extracted/431c2c946d216530b2666a0e7ffa5ac3f5b3da89dd28858a9de6c78fae7caa4a/dev_clean/1272/128104/1272-128104-0000.flac',\n",
52
+ " 'array': array([0.00238037, 0.0020752 , 0.00198364, ..., 0.00042725, 0.00057983,\n",
53
+ " 0.0010376 ]),\n",
54
+ " 'sampling_rate': 16000}"
55
+ ]
56
+ },
57
+ "execution_count": 3,
58
+ "metadata": {},
59
+ "output_type": "execute_result"
60
+ }
61
+ ],
62
+ "source": [
63
+ "sample"
64
+ ]
65
+ },
66
+ {
67
+ "cell_type": "code",
68
+ "execution_count": 4,
69
+ "id": "b7e570a1-cf5c-450c-a7b6-49b45a10d2df",
70
+ "metadata": {},
71
+ "outputs": [],
72
+ "source": [
73
+ "input_features = processor(sample[\"array\"], sampling_rate=sample[\"sampling_rate\"], return_tensors=\"pt\").input_features "
74
+ ]
75
+ },
76
+ {
77
+ "cell_type": "code",
78
+ "execution_count": 5,
79
+ "id": "584e920b-a7fd-402d-95dd-3b9128cd34bb",
80
+ "metadata": {},
81
+ "outputs": [],
82
+ "source": [
83
+ "# generate token ids\n",
84
+ "predicted_ids = model.generate(input_features)\n",
85
+ "# decode token ids to text\n",
86
+ "transcription = processor.batch_decode(predicted_ids, skip_special_tokens=False)\n",
87
+ "\n",
88
+ "transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)"
89
+ ]
90
+ },
91
+ {
92
+ "cell_type": "code",
93
+ "execution_count": 6,
94
+ "id": "b27ab660-861b-49d1-81f9-f51cb7f9d8d8",
95
+ "metadata": {},
96
+ "outputs": [
97
+ {
98
+ "data": {
99
+ "text/plain": [
100
+ "[' Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.']"
101
+ ]
102
+ },
103
+ "execution_count": 6,
104
+ "metadata": {},
105
+ "output_type": "execute_result"
106
+ }
107
+ ],
108
+ "source": [
109
+ "transcription"
110
+ ]
111
+ },
112
+ {
113
+ "cell_type": "code",
114
+ "execution_count": 3,
115
+ "id": "eca553b8-68f6-493d-b567-3d526b49ae1b",
116
+ "metadata": {},
117
+ "outputs": [],
118
+ "source": [
119
+ "import torch\n",
120
+ "from torch import nn"
121
+ ]
122
+ },
123
+ {
124
+ "cell_type": "code",
125
+ "execution_count": 4,
126
+ "id": "c619a4cf-9068-4e4d-8139-e16d15345f4f",
127
+ "metadata": {},
128
+ "outputs": [],
129
+ "source": [
130
+ "from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoTokenizer"
131
+ ]
132
+ },
133
+ {
134
+ "cell_type": "code",
135
+ "execution_count": 5,
136
+ "id": "47d5b1ff-ab0f-4d11-af64-d2fa2be39286",
137
+ "metadata": {},
138
+ "outputs": [
139
+ {
140
+ "name": "stderr",
141
+ "output_type": "stream",
142
+ "text": [
143
+ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
144
+ ]
145
+ }
146
+ ],
147
+ "source": [
148
+ "model_name = \"microsoft/phi-2\"\n",
149
+ "phi2_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n",
150
+ "phi2_tokenizer.pad_token = phi2_tokenizer.eos_token"
151
+ ]
152
+ },
153
+ {
154
+ "cell_type": "code",
155
+ "execution_count": 6,
156
+ "id": "0b36b3f0-db5b-4029-9072-0a53bcab315a",
157
+ "metadata": {},
158
+ "outputs": [
159
+ {
160
+ "ename": "NameError",
161
+ "evalue": "name 'transcription' is not defined",
162
+ "output_type": "error",
163
+ "traceback": [
164
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
165
+ "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
166
+ "Cell \u001b[0;32mIn[6], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m tokens \u001b[38;5;241m=\u001b[39m phi2_tokenizer(\u001b[38;5;241m*\u001b[39m\u001b[43mtranscription\u001b[49m, return_tensors\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpt\u001b[39m\u001b[38;5;124m\"\u001b[39m, return_attention_mask\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n",
167
+ "\u001b[0;31mNameError\u001b[0m: name 'transcription' is not defined"
168
+ ]
169
+ }
170
+ ],
171
+ "source": [
172
+ "tokens = phi2_tokenizer(*transcription, return_tensors=\"pt\", return_attention_mask=False)"
173
+ ]
174
+ },
175
+ {
176
+ "cell_type": "code",
177
+ "execution_count": 22,
178
+ "id": "91f6d3d3-bb00-434f-a91e-6952375890d0",
179
+ "metadata": {},
180
+ "outputs": [
181
+ {
182
+ "data": {
183
+ "text/plain": [
184
+ "{'input_ids': tensor([[ 1770, 13, 2264, 346, 353, 318, 262, 46329, 286, 262,\n",
185
+ " 3504, 6097, 290, 356, 389, 9675, 284, 7062, 465, 21443,\n",
186
+ " 13]])}"
187
+ ]
188
+ },
189
+ "execution_count": 22,
190
+ "metadata": {},
191
+ "output_type": "execute_result"
192
+ }
193
+ ],
194
+ "source": [
195
+ "tokens"
196
+ ]
197
+ },
198
+ {
199
+ "cell_type": "code",
200
+ "execution_count": 12,
201
+ "id": "533191d9-4b3b-417a-918d-6fe854f24b50",
202
+ "metadata": {},
203
+ "outputs": [
204
+ {
205
+ "name": "stderr",
206
+ "output_type": "stream",
207
+ "text": [
208
+ "A new version of the following files was downloaded from https://huggingface.co/microsoft/phi-2:\n",
209
+ "- configuration_phi.py\n",
210
+ ". Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.\n"
211
+ ]
212
+ },
213
+ {
214
+ "data": {
215
+ "application/vnd.jupyter.widget-view+json": {
216
+ "model_id": "2a65a119388b4cb4b123b532176e786e",
217
+ "version_major": 2,
218
+ "version_minor": 0
219
+ },
220
+ "text/plain": [
221
+ "modeling_phi.py: 0%| | 0.00/62.7k [00:00<?, ?B/s]"
222
+ ]
223
+ },
224
+ "metadata": {},
225
+ "output_type": "display_data"
226
+ },
227
+ {
228
+ "name": "stderr",
229
+ "output_type": "stream",
230
+ "text": [
231
+ "A new version of the following files was downloaded from https://huggingface.co/microsoft/phi-2:\n",
232
+ "- modeling_phi.py\n",
233
+ ". Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.\n"
234
+ ]
235
+ },
236
+ {
237
+ "data": {
238
+ "application/vnd.jupyter.widget-view+json": {
239
+ "model_id": "7183811844304c16b72d53fe11098a74",
240
+ "version_major": 2,
241
+ "version_minor": 0
242
+ },
243
+ "text/plain": [
244
+ "Downloading shards: 0%| | 0/2 [00:00<?, ?it/s]"
245
+ ]
246
+ },
247
+ "metadata": {},
248
+ "output_type": "display_data"
249
+ },
250
+ {
251
+ "data": {
252
+ "application/vnd.jupyter.widget-view+json": {
253
+ "model_id": "3e78fe144e8f42139a4d7a1830dbf192",
254
+ "version_major": 2,
255
+ "version_minor": 0
256
+ },
257
+ "text/plain": [
258
+ "Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
259
+ ]
260
+ },
261
+ "metadata": {},
262
+ "output_type": "display_data"
263
+ }
264
+ ],
265
+ "source": [
266
+ "bnb_config = BitsAndBytesConfig(\n",
267
+ " load_in_4bit=True,\n",
268
+ " bnb_4bit_quant_type=\"nf4\",\n",
269
+ " bnb_4bit_compute_dtype=torch.float16,\n",
270
+ ")\n",
271
+ "\n",
272
+ "model = AutoModelForCausalLM.from_pretrained(\n",
273
+ " model_name,\n",
274
+ " quantization_config=bnb_config,\n",
275
+ " trust_remote_code=True,\n",
276
+ " device_map=\"cuda:0\"\n",
277
+ ")\n",
278
+ "model.config.use_cache = False"
279
+ ]
280
+ },
281
+ {
282
+ "cell_type": "code",
283
+ "execution_count": 19,
284
+ "id": "155c054a-a00f-4ed5-bfff-1ad64889e7f1",
285
+ "metadata": {},
286
+ "outputs": [
287
+ {
288
+ "data": {
289
+ "text/plain": [
290
+ "[' Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.\\n']"
291
+ ]
292
+ },
293
+ "execution_count": 19,
294
+ "metadata": {},
295
+ "output_type": "execute_result"
296
+ }
297
+ ],
298
+ "source": [
299
+ "phi2_tokenizer.batch_decode(model.generate(**tokens))"
300
+ ]
301
+ },
302
+ {
303
+ "cell_type": "code",
304
+ "execution_count": 7,
305
+ "id": "04f940c9-586d-4937-ae31-cc0f96d33e92",
306
+ "metadata": {},
307
+ "outputs": [],
308
+ "source": [
309
+ "class AudioLanguageConnector:\n",
310
+ " def __init__(self):\n",
311
+ " model_name = \"microsoft/phi-2\"\n",
312
+ " self.phi2_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n",
313
+ " self.phi2_tokenizer.pad_token = self.phi2_tokenizer.eos_token\n",
314
+ "\n",
315
+ " def __call__(self, text):\n",
316
+ " text = f\"<audio_start> {text} <audio_end>\"\n",
317
+ " tokens = self.phi2_tokenizer(text, return_tensors=\"pt\", return_attention_mask=False)\n",
318
+ " return tokens\n",
319
+ " \n",
320
+ "\n",
321
+ "class WhisperWithProjection:\n",
322
+ " def __init__(self):\n",
323
+ " self.processor = WhisperProcessor.from_pretrained(\"openai/whisper-tiny\")\n",
324
+ " self.model = WhisperForConditionalGeneration.from_pretrained(\"openai/whisper-tiny\")\n",
325
+ " self.model.config.forced_decoder_ids = None\n",
326
+ " self.audio_language_connector = AudioLanguageConnector()\n",
327
+ " \n",
328
+ " def forward(self, audio):\n",
329
+ " input_features = self.processor(audio[\"array\"],\n",
330
+ " sampling_rate=audio[\"sampling_rate\"],\n",
331
+ " return_tensors=\"pt\").input_features\n",
332
+ " # generate token ids\n",
333
+ " predicted_ids = self.model.generate(input_features)\n",
334
+ " # decode token ids to text \n",
335
+ " transcription = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)\n",
336
+ "\n",
337
+ " audio_embeddings = self.audio_language_connector(transcription)\n",
338
+ " return audio_embeddings"
339
+ ]
340
+ },
341
+ {
342
+ "cell_type": "code",
343
+ "execution_count": 8,
344
+ "id": "2b1f8f44-bfe6-413c-9e32-c38fa5517981",
345
+ "metadata": {},
346
+ "outputs": [],
347
+ "source": [
348
+ "class TextModality:\n",
349
+ " def __init__(self):\n",
350
+ " model_name = \"microsoft/phi-2\"\n",
351
+ " self.phi2_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n",
352
+ " self.phi2_tokenizer.pad_token = self.phi2_tokenizer.eos_token\n",
353
+ "\n",
354
+ " def __call__(self, text):\n",
355
+ " tokens = self.phi2_tokenizer(text, return_tensors=\"pt\", return_attention_mask=False)\n",
356
+ " return tokens"
357
+ ]
358
+ },
359
+ {
360
+ "cell_type": "code",
361
+ "execution_count": 15,
362
+ "id": "21c51648-abb6-4bbd-b4c1-509967a69337",
363
+ "metadata": {},
364
+ "outputs": [],
365
+ "source": [
366
+ "class MultiModalPhi2:\n",
367
+ " def __init__(self):\n",
368
+ " self.text_modality = TextModality()\n",
369
+ " self.whisper_w_proj = WhisperWithProjection()\n",
370
+ " self.llm = self.load_llm()\n",
371
+ "\n",
372
+ " def load_llm(self):\n",
373
+ " bnb_config = BitsAndBytesConfig(\n",
374
+ " load_in_4bit=True,\n",
375
+ " bnb_4bit_quant_type=\"nf4\",\n",
376
+ " bnb_4bit_compute_dtype=torch.float16)\n",
377
+ " \n",
378
+ " model = AutoModelForCausalLM.from_pretrained(\n",
379
+ " model_name,\n",
380
+ " quantization_config=bnb_config,\n",
381
+ " trust_remote_code=True,\n",
382
+ " device_map=\"cuda:0\"\n",
383
+ " )\n",
384
+ " model.config.use_cache = False\n",
385
+ " return model\n",
386
+ "\n",
387
+ " def generate(self, audio, text):\n",
388
+ " text_embeddings = self.text_modality(text)\n",
389
+ " audio_embeddings = self.whisper_w_proj.forward(audio)\n",
390
+ " inputs = torch.concat([text_embeddings[\"input_ids\"], audio_embeddings[\"input_ids\"]], dim=1)\n",
391
+ " \n",
392
+ " # outputs = self.llm.generate(inputs, max_length=200)\n",
393
+ " outputs = self.llm(inputs)\n",
394
+ " return outputs\n",
395
+ " \n",
396
+ " # text = self.text_modality.phi2_tokenizer.batch_decode(outputs)[0]\n",
397
+ " # print(text)"
398
+ ]
399
+ },
400
+ {
401
+ "cell_type": "code",
402
+ "execution_count": 16,
403
+ "id": "472a00cb-bae9-4c09-a0ef-bc57881b5e2c",
404
+ "metadata": {},
405
+ "outputs": [
406
+ {
407
+ "name": "stderr",
408
+ "output_type": "stream",
409
+ "text": [
410
+ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
411
+ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
412
+ ]
413
+ },
414
+ {
415
+ "data": {
416
+ "application/vnd.jupyter.widget-view+json": {
417
+ "model_id": "2236e6b1e26d444fa3d48181ba1a6cf9",
418
+ "version_major": 2,
419
+ "version_minor": 0
420
+ },
421
+ "text/plain": [
422
+ "Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
423
+ ]
424
+ },
425
+ "metadata": {},
426
+ "output_type": "display_data"
427
+ }
428
+ ],
429
+ "source": [
430
+ "multi_modal_phi = MultiModalPhi2()"
431
+ ]
432
+ },
433
+ {
434
+ "cell_type": "code",
435
+ "execution_count": 17,
436
+ "id": "c350f2d3-0929-4c46-b63d-ff92dea437f3",
437
+ "metadata": {},
438
+ "outputs": [
439
+ {
440
+ "data": {
441
+ "text/plain": [
442
+ "CausalLMOutputWithPast(loss={'logits': tensor([[[ 6.9531, 9.9375, 7.0234, ..., 2.0020, 2.0020, 2.0000],\n",
443
+ " [ 8.9062, 12.1172, 7.5977, ..., -1.2012, -1.2012, -1.2012],\n",
444
+ " [ 7.0273, 5.3477, 3.6328, ..., -4.2070, -4.2070, -4.2070],\n",
445
+ " ...,\n",
446
+ " [ 7.0234, 7.4414, 9.1016, ..., 1.0117, 1.0127, 1.0117],\n",
447
+ " [ 9.4531, 10.0391, 9.7578, ..., 0.0776, 0.0775, 0.0764],\n",
448
+ " [ 8.0703, 6.6445, 5.5156, ..., -1.9268, -1.9268, -1.9277]]],\n",
449
+ " grad_fn=<ToCopyBackward0>)}, logits=tensor([[[ 6.9531, 9.9375, 7.0234, ..., 2.0020, 2.0020, 2.0000],\n",
450
+ " [ 8.9062, 12.1172, 7.5977, ..., -1.2012, -1.2012, -1.2012],\n",
451
+ " [ 7.0273, 5.3477, 3.6328, ..., -4.2070, -4.2070, -4.2070],\n",
452
+ " ...,\n",
453
+ " [ 7.0234, 7.4414, 9.1016, ..., 1.0117, 1.0127, 1.0117],\n",
454
+ " [ 9.4531, 10.0391, 9.7578, ..., 0.0776, 0.0775, 0.0764],\n",
455
+ " [ 8.0703, 6.6445, 5.5156, ..., -1.9268, -1.9268, -1.9277]]],\n",
456
+ " grad_fn=<ToCopyBackward0>), past_key_values=None, hidden_states=None, attentions=None)"
457
+ ]
458
+ },
459
+ "execution_count": 17,
460
+ "metadata": {},
461
+ "output_type": "execute_result"
462
+ }
463
+ ],
464
+ "source": [
465
+ "audio = sample\n",
466
+ "text = \"explain about the audio\"\n",
467
+ "multi_modal_phi.generate(audio, text)"
468
+ ]
469
+ },
470
+ {
471
+ "cell_type": "code",
472
+ "execution_count": null,
473
+ "id": "46aa9c66-a5bb-4760-8895-92673f49345f",
474
+ "metadata": {},
475
+ "outputs": [],
476
+ "source": []
477
+ }
478
+ ],
479
+ "metadata": {
480
+ "kernelspec": {
481
+ "display_name": "Python 3 (ipykernel)",
482
+ "language": "python",
483
+ "name": "python3"
484
+ },
485
+ "language_info": {
486
+ "codemirror_mode": {
487
+ "name": "ipython",
488
+ "version": 3
489
+ },
490
+ "file_extension": ".py",
491
+ "mimetype": "text/x-python",
492
+ "name": "python",
493
+ "nbconvert_exporter": "python",
494
+ "pygments_lexer": "ipython3",
495
+ "version": "3.10.12"
496
+ }
497
+ },
498
+ "nbformat": 4,
499
+ "nbformat_minor": 5
500
+ }
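Note that in the experiment above, `WhisperWithProjection` does not yet project anything: it transcribes the audio with whisper-tiny and re-tokenizes the resulting text, and the `<audio_start>`/`<audio_end>` markers are ordinary text rather than registered special tokens. If the eventual design (see the README diagram below) is to feed Whisper's encoder states into Phi-2's embedding space directly, a trainable connector along the following lines is one way to sketch it. This is an assumption about the intended architecture, not code from this repository; the only non-illustrative numbers are the published model widths (384 for whisper-tiny's encoder, 2560 for Phi-2's hidden size):

```python
import torch
from torch import nn
from transformers import WhisperModel, WhisperProcessor


class AudioProjector(nn.Module):
    """Hypothetical connector: Whisper encoder states -> LLM embedding space."""

    def __init__(self, audio_dim: int = 384, llm_dim: int = 2560):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Linear(audio_dim, llm_dim),
            nn.GELU(),
            nn.Linear(llm_dim, llm_dim),
        )

    def forward(self, encoder_states: torch.Tensor) -> torch.Tensor:
        # (batch, frames, 384) -> (batch, frames, 2560)
        return self.proj(encoder_states)


processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
whisper = WhisperModel.from_pretrained("openai/whisper-tiny")
projector = AudioProjector()


def audio_to_llm_embeds(audio_array, sampling_rate=16000):
    feats = processor(audio_array, sampling_rate=sampling_rate,
                      return_tensors="pt").input_features
    with torch.no_grad():
        enc = whisper.encoder(feats).last_hidden_state   # (1, 1500, 384) for a 30s padded clip
    return projector(enc)                                # (1, 1500, 2560)
```

The projected tensor could then be concatenated with the text token embeddings (via `model.get_input_embeddings()(input_ids)`) and passed to Phi-2 as `inputs_embeds`, instead of concatenating token ids as the notebook's `generate` method does.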
README.md CHANGED
@@ -1,13 +1,44 @@
1
  ---
2
  title: MultiModal Phi2
3
- emoji: 😻
4
- colorFrom: gray
5
- colorTo: purple
6
  sdk: gradio
7
- sdk_version: 4.16.0
8
  app_file: app.py
9
  pinned: false
10
  license: mit
11
  ---
 
 
 
 
 
 
 
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  title: MultiModal Phi2
3
+ emoji: 🚀
4
+ colorFrom: blue
5
+ colorTo: red
6
  sdk: gradio
7
+ sdk_version: 3.35.2
8
  app_file: app.py
9
  pinned: false
10
  license: mit
11
  ---
12
+ ## Phi2: Multimodal Finetuning
13
+ ### Details
14
+ 1. LLM Backbone: Phi2
15
+ 2. Vision Tower: clip-vit-large-patch14-336
16
+ 3. Audio Model: Whisper
17
+ 4. Pretraining Dataset: LAION-CC-SBU dataset with BLIP captions (200k samples)
18
+ 5. Finetuning Dataset: Instruct 150k dataset based on COCO
19
 
20
+ ### Design
21
+ ![image](https://github.com/RaviNaik/ERA-CAPSTONE/assets/23289802/56df24cd-2681-4e17-ab64-9652f609b15f)
22
+
23
+ ### Pretraining
24
+ #### Training Loss Curve
25
+ ![image](https://github.com/RaviNaik/ERA-CAPSTONE/assets/23289802/b6c37a95-0a56-4b52-8719-3ff56dc1b703)
26
+
27
+ #### Learning Rate
28
+ ![image](https://github.com/RaviNaik/ERA-CAPSTONE/assets/23289802/44d9a11b-b28d-47e1-ba1d-d6dc22ebe748)
29
+
30
+ #### Training Logs
31
+ ![image](https://github.com/RaviNaik/ERA-CAPSTONE/assets/23289802/76543d98-d9fe-4c1a-ac47-3d06e48053ad)
32
+
33
+ ### Finetuning
34
+ #### Training Loss Curve
35
+ ![image](https://github.com/RaviNaik/ERA-CAPSTONE/assets/23289802/45ef40bd-fae5-4cfe-a522-c0eed2833230)
36
+
37
+ #### Learning Rate
38
+ ![image](https://github.com/RaviNaik/ERA-CAPSTONE/assets/23289802/df60ee62-a537-4e36-a7f7-f7111e101162)
39
+
40
+ #### Training Logs
41
+ ![image](https://github.com/RaviNaik/ERA-CAPSTONE/assets/23289802/2747acce-bc99-4c37-a05a-d5e81cb9aa9d)
42
+
43
+ ### Results
44
+ ![image](https://github.com/RaviNaik/ERA-CAPSTONE/assets/23289802/f12a9f04-df32-413e-b957-774c30381b2b)
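For programmatic use outside Gradio, the finetuned checkpoint `RaviNaik/Llava-Phi2` is driven through the repository's own `inference.main.MultiModalPhi2` wrapper, exactly as `app.py` below does. A minimal usage sketch that mirrors that constructor and call signature (the wrapper is local to this repo, not a library API; the file names are placeholders):

```python
from PIL import Image
from inference.main import MultiModalPhi2   # local wrapper, same import as app.py

model = MultiModalPhi2(
    modelname_or_path="RaviNaik/Llava-Phi2",
    temperature=0.2,
    max_new_tokens=1024,
    device="cpu",
)

image = Image.open("example.jpg")   # placeholder image path
audio = "example.mp3"               # placeholder audio path (app.py passes file paths)
text = "Describe what is happening in the image and audio."

answer = model(text, audio, image)  # argument order as used in app.py's run()
print(answer)
```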
app.py ADDED
@@ -0,0 +1,128 @@
1
+ import gradio as gr
2
+ from PIL import Image
3
+ from inference.main import MultiModalPhi2
4
+
5
+ messages = []
6
+
7
+ multimodal_phi2 = MultiModalPhi2(
8
+ modelname_or_path="RaviNaik/Llava-Phi2",
9
+ temperature=0.2,
10
+ max_new_tokens=1024,
11
+ device="cpu",
12
+ )
13
+
14
+
15
+ def add_content(chatbot, text, image, audio_upload, audio_mic) -> gr.Chatbot:
16
+ textflag, imageflag, audioflag = False, False, False
17
+ if text not in ["", None]:
18
+ chatbot.append((text, None))
19
+ textflag = True
20
+ if image is not None:
21
+ chatbot.append(((image,), None))
22
+ imageflag = True
23
+ if audio_mic is not None:
24
+ chatbot.append(((audio_mic,), None))
25
+ audioflag = True
26
+ else:
27
+ if audio_upload is not None:
28
+ chatbot.append(((audio_upload,), None))
29
+ audioflag = True
30
+ if not any([textflag, imageflag, audioflag]):
31
+ # Raise an error if neither text nor file is provided
32
+ raise gr.Error("Enter a valid text, image or audio")
33
+ return chatbot
34
+
35
+
36
+ def clear_data():
37
+ return {prompt: None, image: None, audio_upload: None, audio_mic: None, chatbot: []}
38
+
39
+
40
+ def run(history, text, image, audio_upload, audio_mic):
41
+ if text in [None, ""]:
42
+ text = None
43
+
44
+ if audio_upload is not None:
45
+ audio = audio_upload
46
+ elif audio_mic is not None:
47
+ audio = audio_mic
48
+ else:
49
+ audio = None
50
+
51
+ print("text", text)
52
+ print("image", image)
53
+ print("audio", audio)
54
+
55
+ if image is not None:
56
+ image = Image.open(image)
57
+ outputs = multimodal_phi2(text, audio, image)
58
+ # outputs = ""
59
+
60
+ history.append((None, outputs.title()))
61
+ return history, None, None, None, None
62
+
63
+
64
+ with gr.Blocks() as demo:
65
+ gr.Markdown("## MulitModal Phi2 Model Pretraining and Finetuning from Scratch")
66
+ gr.Markdown(
67
+ """This is a multimodal implementation of [Phi2](https://huggingface.co/microsoft/phi-2) model.
68
+
69
+ Please find the source code and training details [here](https://github.com/RaviNaik/ERA-CAPSTONE/MultiModalPhi2).
70
+
71
+ ### Details:
72
+ 1. LLM Backbone: [Phi2](https://huggingface.co/microsoft/phi-2)
73
+ 2. Vision Tower: [clip-vit-large-patch14-336](https://huggingface.co/openai/clip-vit-large-patch14-336)
74
+ 3. Audio Model: [Whisper Tiny](https://huggingface.co/openai/whisper-tiny)
75
+ 4. Pretraining Dataset: [LAION-CC-SBU dataset with BLIP captions(200k samples)](https://huggingface.co/datasets/liuhaotian/LLaVA-Pretrain)
76
+ 5. Finetuning Dataset: [Instruct 150k dataset based on COCO](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K)
77
+ 6. Finetuned Model: [RaviNaik/Llava-Phi2](https://huggingface.co/RaviNaik/Llava-Phi2)
78
+ """
79
+ )
80
+ with gr.Row():
81
+ with gr.Column(scale=4):
82
+ # Creating a column with a scale of 6
83
+ with gr.Box():
84
+ with gr.Row():
85
+ # Adding a Textbox with a placeholder "write prompt"
86
+ prompt = gr.Textbox(
87
+ placeholder="Enter Prompt", lines=2, label="Query", value=None
88
+ )
89
+ # Creating a column with a scale of 2
90
+ with gr.Row():
91
+ # Adding image
92
+ image = gr.Image(type="filepath", value=None)
93
+ # Creating a column with a scale of 2
94
+ with gr.Row():
95
+ # Add audio
96
+ audio_upload = gr.Audio(source="upload", type="filepath")
97
+ audio_mic = gr.Audio(
98
+ source="microphone", type="filepath", format="mp3"
99
+ )
100
+
101
+ with gr.Column(scale=8):
102
+ with gr.Box():
103
+ with gr.Row():
104
+ chatbot = gr.Chatbot(
105
+ avatar_images=("🧑", "🤖"),
106
+ height=550,
107
+ )
108
+ with gr.Row():
109
+ # Adding a Button
110
+ submit = gr.Button()
111
+ clear = gr.Button(value="Clear")
112
+
113
+ submit.click(
114
+ add_content,
115
+ inputs=[chatbot, prompt, image, audio_upload, audio_mic],
116
+ outputs=[chatbot],
117
+ ).success(
118
+ run,
119
+ inputs=[chatbot, prompt, image, audio_upload, audio_mic],
120
+ outputs=[chatbot, prompt, image, audio_upload, audio_mic],
121
+ )
122
+
123
+ clear.click(
124
+ clear_data,
125
+ outputs=[prompt, image, audio_upload, audio_mic, chatbot],
126
+ )
127
+
128
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,21 @@
1
+ einops==0.6.1
2
+ einops-exts==0.0.4
3
+ timm==0.6.13
4
+ gradio==3.35.2
5
+ gradio_client==0.2.9
6
+ markdown2[all]
7
+ numpy
8
+ requests
9
+ tokenizers==0.15.0
10
+ torch==2.0.1
11
+ shortuuid
12
+ httpx==0.24.0
13
+ deepspeed==0.9.5
14
+ peft==0.4.0
15
+ transformers==4.36.2
16
+ accelerate==0.21.0
17
+ bitsandbytes==0.41.0
18
+ scikit-learn==1.2.2
19
+ sentencepiece==0.1.99
20
+ librosa
21
+ soundfile