Upload 13 files
- .gitignore +160 -0
- Experiments/clip_expt.ipynb +840 -0
- Experiments/eval.ipynb +782 -0
- Experiments/instruct_150k_data.ipynb +0 -0
- Experiments/instruct_data.py +39 -0
- Experiments/llava_exp.ipynb +145 -0
- Experiments/multimodal_exp.ipynb +362 -0
- Experiments/pretrain_data_check.ipynb +304 -0
- Experiments/whispher_exp.ipynb +500 -0
- README.md +36 -5
- app.py +128 -0
- requirements.txt +21 -0
.gitignore
ADDED
@@ -0,0 +1,160 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
Experiments/clip_expt.ipynb
ADDED
@@ -0,0 +1,840 @@
(Jupyter notebook JSON; the added cells are rendered below with their execution counts and outputs.)

In [2]:
from PIL import Image
import requests
from transformers import AutoProcessor, CLIPVisionModel

In [3]:
model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32", device_map="cuda:0")
processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32", device_map="cuda:0")

In [4]:
# url = "http://images.cocodataset.org/val2017/000000039769.jpg"
# image = Image.open(requests.get(url, stream=True).raw)
image = Image.open("002579.jpg")

In [17]:
# image

In [18]:
# image = None

In [5]:
inputs = processor(images=image, return_tensors="pt", device_map="cuda:0")

In [20]:
# inputs

In [6]:
inputs["pixel_values"].shape
Out[6]: torch.Size([1, 3, 224, 224])

In [22]:
outputs = model(inputs["pixel_values"].to("cuda:0"))
last_hidden_state = outputs.last_hidden_state
pooled_output = outputs.pooler_output  # pooled CLS states

In [23]:
pooled_output.shape
Out[23]: torch.Size([1, 768])

In [24]:
last_hidden_state.shape
Out[24]: torch.Size([1, 50, 768])

In [25]:
outputs[1].shape
Out[25]: torch.Size([1, 768])
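The cells above pin down the shapes used throughout these experiments: the processor turns one image into a [1, 3, 224, 224] pixel batch, the ViT-B/32 vision tower returns a [1, 50, 768] last_hidden_state (CLS token plus 7x7 patch tokens) and a [1, 768] pooler_output, and indexing the output tuple as outputs[1] yields the same pooled tensor. A consolidated sketch of that flow, assuming a local 002579.jpg as in the notebook (the commented-out COCO URL works too) and running on CPU:

from PIL import Image
from transformers import AutoProcessor, CLIPVisionModel

# Load the ViT-B/32 vision tower and its matching image processor.
model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

image = Image.open("002579.jpg")                        # any RGB image
inputs = processor(images=image, return_tensors="pt")   # pixel_values: [1, 3, 224, 224]

outputs = model(**inputs)
print(outputs.last_hidden_state.shape)  # torch.Size([1, 50, 768]) - CLS + 49 patch tokens
print(outputs.pooler_output.shape)      # torch.Size([1, 768])     - layer-normed CLS embedding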
In [8]:
import torch
from torch import nn
from transformers import CLIPVisionConfig, CLIPPreTrainedModel

In [9]:
class CLIPVisionModelWithProjection(CLIPPreTrainedModel):
    config_class = CLIPVisionConfig
    main_input_name = "pixel_values"

    def __init__(self, config: CLIPVisionConfig):
        super().__init__(config)

        self.vision_model = CLIPVisionTransformer(config)

        self.visual_projection = nn.Linear(config.hidden_size, config.projection_dim, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.vision_model.embeddings.patch_embedding

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CLIPVisionModelOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = vision_outputs[1]  # pooled_output

        image_embeds = self.visual_projection(pooled_output)

        if not return_dict:
            outputs = (image_embeds, vision_outputs[0]) + vision_outputs[2:]
            return tuple(output for output in outputs if output is not None)

        return CLIPVisionModelOutput(
            image_embeds=image_embeds,
            last_hidden_state=vision_outputs.last_hidden_state,
            hidden_states=vision_outputs.hidden_states,
            attentions=vision_outputs.attentions,
        )

---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
Cell In[9], line 1
----> 1 class CLIPVisionModelWithProjection(CLIPPreTrainedModel):
      2     config_class = CLIPVisionConfig
      3     main_input_name = "pixel_values"
Cell In[9], line 20, in CLIPVisionModelWithProjection()
     18     def forward(
     19         self,
---> 20         pixel_values: Optional[torch.FloatTensor] = None,
     21         output_attentions: Optional[bool] = None,
     22         output_hidden_states: Optional[bool] = None,
     23         return_dict: Optional[bool] = None,
     24     ) -> Union[Tuple, CLIPVisionModelOutput]:
NameError: name 'Optional' is not defined
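The NameError is an ordering problem: the class body uses typing names (Optional, Union, Tuple) and CLIP internals (CLIPVisionTransformer, CLIPVisionModelOutput) that are only imported later, in In [10]. A minimal import block that would let the cell above execute (a sketch; note that transformers also ships its own CLIPVisionModelWithProjection, so re-defining it is only needed for experimentation):

from typing import Optional, Tuple, Union

import torch
from torch import nn
from transformers import CLIPPreTrainedModel, CLIPVisionConfig
# Internal building blocks referenced by the class body:
from transformers.models.clip.modeling_clip import (
    CLIPVisionModelOutput,   # output dataclass carrying image_embeds
    CLIPVisionTransformer,   # the bare vision transformer used as self.vision_model
)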
In [27]:
from PIL import Image
import requests
from transformers import AutoProcessor, CLIPModel

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(images=image, return_tensors="pt")

image_features = model.get_image_features(**inputs)

(stderr, verbose transformers logging)
loading configuration file config.json from cache at /home/ravi.naik/.cache/huggingface/hub/models--openai--clip-vit-base-patch32/snapshots/e6a30b603a447e251fdaca1c3056b2a16cdfebeb/config.json
`text_config` is `None`. Initializing the `CLIPTextConfig` with default values.
`vision_config` is `None`. initializing the `CLIPVisionConfig` with default values.
Model config CLIPConfig {
  "_name_or_path": "openai/clip-vit-base-patch32",
  "architectures": ["CLIPModel"],
  "initializer_factor": 1.0,
  "logit_scale_init_value": 2.6592,
  "model_type": "clip",
  "projection_dim": 512,
  "text_config": {"bos_token_id": 0, "dropout": 0.0, "eos_token_id": 2, "model_type": "clip_text_model"},
  "transformers_version": "4.36.2",
  "vision_config": {"dropout": 0.0, "model_type": "clip_vision_model"}
}
loading weights file pytorch_model.bin from cache at /home/ravi.naik/.cache/huggingface/hub/models--openai--clip-vit-base-patch32/snapshots/e6a30b603a447e251fdaca1c3056b2a16cdfebeb/pytorch_model.bin
All model checkpoint weights were used when initializing CLIPModel.
All the weights of CLIPModel were initialized from the model checkpoint at openai/clip-vit-base-patch32.
If your task is similar to the task the model of the checkpoint was trained on, you can already use CLIPModel for predictions without further training.
loading configuration file preprocessor_config.json from the same cache snapshot
size should be a dictionary on of the following set of keys: ({'width', 'height'}, {'shortest_edge'}, {'longest_edge', 'shortest_edge'}, {'longest_edge'}), got 224. Converted to {'shortest_edge': 224}.
crop_size should be a dictionary on of the following set of keys: ({'width', 'height'}, {'shortest_edge'}, {'longest_edge', 'shortest_edge'}, {'longest_edge'}), got 224. Converted to {'height': 224, 'width': 224}.
Image processor CLIPImageProcessor {
  "crop_size": {"height": 224, "width": 224},
  "do_center_crop": true,
  "do_convert_rgb": true,
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "feature_extractor_type": "CLIPFeatureExtractor",
  "image_mean": [0.48145466, 0.4578275, 0.40821073],
  "image_processor_type": "CLIPImageProcessor",
  "image_std": [0.26862954, 0.26130258, 0.27577711],
  "resample": 3,
  "rescale_factor": 0.00392156862745098,
  "size": {"shortest_edge": 224}
}
(the tokenizer files vocab.json, merges.txt, tokenizer.json, special_tokens_map.json and tokenizer_config.json are then loaded from the same snapshot, and the CLIPConfig dump above is printed again.)
In [29]:
image_features.shape
Out[29]: torch.Size([1, 512])
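The 512-d result here is different from the 768-d pooler_output of the earlier cells: get_image_features runs the vision tower, takes the pooled CLS state (768, the hidden_size) and pushes it through CLIPModel's visual_projection, a bias-free Linear 768 -> 512 (projection_dim in the config dump above), landing in the shared image-text embedding space. A sketch that makes the relationship explicit, assuming the model, processor and inputs from In [27] are still in scope:

import torch

# get_image_features = vision tower -> pooled CLS (768) -> visual_projection (768 -> 512)
with torch.no_grad():
    vision_out = model.vision_model(pixel_values=inputs["pixel_values"])
    pooled = vision_out.pooler_output               # [1, 768]
    projected = model.visual_projection(pooled)     # [1, 512]
    direct = model.get_image_features(**inputs)     # [1, 512]

print(torch.allclose(projected, direct))  # True: same computation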
In [30]:
from PIL import Image
import requests
from transformers import AutoProcessor, CLIPVisionModel

model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(images=image, return_tensors="pt")

outputs = model(**inputs)
last_hidden_state = outputs.last_hidden_state
pooled_output = outputs.pooler_output  # pooled CLS states

(stderr, verbose transformers logging)
loading configuration file config.json from cache at /home/ravi.naik/.cache/huggingface/hub/models--openai--clip-vit-base-patch32/snapshots/e6a30b603a447e251fdaca1c3056b2a16cdfebeb/config.json
Model config CLIPVisionConfig {
  "attention_dropout": 0.0,
  "dropout": 0.0,
  "hidden_act": "quick_gelu",
  "hidden_size": 768,
  "image_size": 224,
  "initializer_factor": 1.0,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-05,
  "model_type": "clip_vision_model",
  "num_attention_heads": 12,
  "num_channels": 3,
  "num_hidden_layers": 12,
  "patch_size": 32,
  "projection_dim": 512,
  "transformers_version": "4.36.2"
}
loading weights file pytorch_model.bin from the same cache snapshot
Some weights of the model checkpoint at openai/clip-vit-base-patch32 were not used when initializing CLIPVisionModel: [a long list of 'text_model.*' weights plus 'logit_scale', 'visual_projection.weight' and 'text_projection.weight']
- This IS expected if you are initializing CLIPVisionModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing CLIPVisionModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
All the weights of CLIPVisionModel were initialized from the model checkpoint at openai/clip-vit-base-patch32.
If your task is similar to the task the model of the checkpoint was trained on, you can already use CLIPVisionModel for predictions without further training.
(the preprocessor, tokenizer and CLIPConfig logging shown under In [27] is then repeated.)

In [31]:
pooled_output.shape
Out[31]: torch.Size([1, 768])
In [10]:
from transformers import CLIPPreTrainedModel
from transformers.models.clip.modeling_clip import CLIPVisionModelOutput, CLIPVisionTransformer
from typing import Optional, Union, Tuple
In [54]:
class VisionLanguageConnector(nn.Module):
    def __init__(self, hidden_size, projection_dim):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size, bias=False),
            nn.GELU(),
            nn.Linear(hidden_size, projection_dim, bias=False)
        )

    def forward(self, x):
        return self.mlp(x)

class ClipWithProjection(CLIPPreTrainedModel):
    config_class = CLIPVisionConfig
    main_input_name = "pixel_values"

    def __init__(self, config: CLIPVisionConfig):
        super().__init__(config)

        self.vision_model = CLIPVisionTransformer(config)
        self.vision_language_connector = VisionLanguageConnector(config.hidden_size, config.projection_dim)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CLIPVisionModelOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = vision_outputs[1]  # pooled_output

        image_embeds = self.vision_language_connector(pooled_output)

        if not return_dict:
            outputs = (image_embeds, vision_outputs[0]) + vision_outputs[2:]
            return tuple(output for output in outputs if output is not None)

        return CLIPVisionModelOutput(
            image_embeds=image_embeds,
            last_hidden_state=vision_outputs.last_hidden_state,
            hidden_states=vision_outputs.hidden_states,
            attentions=vision_outputs.attentions,
        )
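ClipWithProjection swaps CLIP's single linear visual_projection for a small two-layer MLP connector (Linear -> GELU -> Linear), similar in spirit to the vision-language connectors used by LLaVA-style models to map the vision tower's 768-d pooled output into another embedding space. A usage sketch, assuming the imports and the `image` from the earlier cells; loading from the full CLIP checkpoint drops the text-tower weights (see the stderr of the next cell) and leaves the new connector randomly initialized:

from transformers import AutoProcessor

connector_model = ClipWithProjection.from_pretrained("openai/clip-vit-base-patch32")
processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

inputs = processor(images=image, return_tensors="pt")  # `image` as defined in the earlier cells
out = connector_model(**inputs)
print(out.image_embeds.shape)  # expected torch.Size([1, 512]) - projection_dim from the vision config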
In [55]:
(stderr, verbose transformers logging)
loading configuration file config.json from cache at /home/ravi.naik/.cache/huggingface/hub/models--openai--clip-vit-base-patch32/snapshots/e6a30b603a447e251fdaca1c3056b2a16cdfebeb/config.json
Model config CLIPVisionConfig { (same values as under In [30]) }
loading weights file pytorch_model.bin from the same cache snapshot
Some weights of the model checkpoint at openai/clip-vit-base-patch32 were not used when initializing ClipWithProjection: [the same long list of 'text_model.*', 'logit_scale', 'visual_projection.weight' and 'text_projection.weight' entries as above]
- This IS expected if you are initializing ClipWithProjection from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ClipWithProjection from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
|
719 |
+
"Some weights of ClipWithProjection were not initialized from the model checkpoint at openai/clip-vit-base-patch32 and are newly initialized: ['vision_language_connector.mlp.2.weight', 'vision_language_connector.mlp.0.weight']\n",
|
720 |
+
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
|
721 |
+
]
|
722 |
+
}
|
723 |
+
],
|
724 |
+
"source": [
|
725 |
+
"model = ClipWithProjection.from_pretrained(\"openai/clip-vit-base-patch32\")"
|
726 |
+
]
|
727 |
+
},
|
728 |
+
{
|
729 |
+
"cell_type": "code",
|
730 |
+
"execution_count": 56,
|
731 |
+
"id": "588ef914-5be9-49e1-b68d-b899e0e74edd",
|
732 |
+
"metadata": {},
|
733 |
+
"outputs": [
|
734 |
+
{
|
735 |
+
"data": {
|
736 |
+
"text/plain": [
|
737 |
+
"768"
|
738 |
+
]
|
739 |
+
},
|
740 |
+
"execution_count": 56,
|
741 |
+
"metadata": {},
|
742 |
+
"output_type": "execute_result"
|
743 |
+
}
|
744 |
+
],
|
745 |
+
"source": [
|
746 |
+
"model.config.hidden_size"
|
747 |
+
]
|
748 |
+
},
|
749 |
+
{
|
750 |
+
"cell_type": "code",
|
751 |
+
"execution_count": 57,
|
752 |
+
"id": "05d95b9e-9831-4415-860e-94793e29d210",
|
753 |
+
"metadata": {},
|
754 |
+
"outputs": [],
|
755 |
+
"source": [
|
756 |
+
"outputs = model(**inputs)"
|
757 |
+
]
|
758 |
+
},
|
759 |
+
{
|
760 |
+
"cell_type": "code",
|
761 |
+
"execution_count": 61,
|
762 |
+
"id": "185b1bff-6ffe-4cce-9255-ee7629feba54",
|
763 |
+
"metadata": {},
|
764 |
+
"outputs": [
|
765 |
+
{
|
766 |
+
"data": {
|
767 |
+
"text/plain": [
|
768 |
+
"torch.Size([1, 512])"
|
769 |
+
]
|
770 |
+
},
|
771 |
+
"execution_count": 61,
|
772 |
+
"metadata": {},
|
773 |
+
"output_type": "execute_result"
|
774 |
+
}
|
775 |
+
],
|
776 |
+
"source": [
|
777 |
+
"outputs[0].shape"
|
778 |
+
]
|
779 |
+
},
|
780 |
+
{
|
781 |
+
"cell_type": "code",
|
782 |
+
"execution_count": null,
|
783 |
+
"id": "04414a35-c7b3-4986-a79e-1d363916caa4",
|
784 |
+
"metadata": {},
|
785 |
+
"outputs": [],
|
786 |
+
"source": []
|
787 |
+
},
|
788 |
+
{
|
789 |
+
"cell_type": "code",
|
790 |
+
"execution_count": 1,
|
791 |
+
"id": "485dbbcb-06df-4926-b257-dfd1a4081d44",
|
792 |
+
"metadata": {},
|
793 |
+
"outputs": [
|
794 |
+
{
|
795 |
+
"ename": "NameError",
|
796 |
+
"evalue": "name 'outputs' is not defined",
|
797 |
+
"output_type": "error",
|
798 |
+
"traceback": [
|
799 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
800 |
+
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
|
801 |
+
"Cell \u001b[0;32mIn[1], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43moutputs\u001b[49m[\u001b[38;5;241m0\u001b[39m]\n",
|
802 |
+
"\u001b[0;31mNameError\u001b[0m: name 'outputs' is not defined"
|
803 |
+
]
|
804 |
+
}
|
805 |
+
],
|
806 |
+
"source": [
|
807 |
+
"outputs[0]"
|
808 |
+
]
|
809 |
+
},
|
810 |
+
{
|
811 |
+
"cell_type": "code",
|
812 |
+
"execution_count": null,
|
813 |
+
"id": "f983313c-8e0f-4805-af14-25bb69afd04c",
|
814 |
+
"metadata": {},
|
815 |
+
"outputs": [],
|
816 |
+
"source": []
|
817 |
+
}
|
818 |
+
],
|
819 |
+
"metadata": {
|
820 |
+
"kernelspec": {
|
821 |
+
"display_name": "Python 3 (ipykernel)",
|
822 |
+
"language": "python",
|
823 |
+
"name": "python3"
|
824 |
+
},
|
825 |
+
"language_info": {
|
826 |
+
"codemirror_mode": {
|
827 |
+
"name": "ipython",
|
828 |
+
"version": 3
|
829 |
+
},
|
830 |
+
"file_extension": ".py",
|
831 |
+
"mimetype": "text/x-python",
|
832 |
+
"name": "python",
|
833 |
+
"nbconvert_exporter": "python",
|
834 |
+
"pygments_lexer": "ipython3",
|
835 |
+
"version": "3.10.12"
|
836 |
+
}
|
837 |
+
},
|
838 |
+
"nbformat": 4,
|
839 |
+
"nbformat_minor": 5
|
840 |
+
}
|
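Note on the clip_expt.ipynb output above: `ClipWithProjection` keeps only the CLIP vision tower, so every `text_model.*` weight of openai/clip-vit-base-patch32 is reported as unused, while the two `vision_language_connector.mlp.*` matrices are reported as newly initialized and still need training. A minimal sketch of that connector, assuming the same 768-dim pooled output and 512-dim projection used in the notebook (the class itself appears again in Experiments/multimodal_exp.ipynb further below):

import requests
import torch
from PIL import Image
from torch import nn
from transformers import AutoProcessor, CLIPVisionModel

class VisionLanguageConnector(nn.Module):
    # Two-layer MLP; its mlp.0 and mlp.2 weights are the "newly initialized" ones in the warning.
    def __init__(self, hidden_size: int = 768, projection_dim: int = 512):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size, bias=False),
            nn.GELU(),
            nn.Linear(hidden_size, projection_dim, bias=False),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.mlp(x)

# Pooled CLIP features (hidden_size=768) -> 512-dim image embedding, matching outputs[0].shape above.
processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
vision_model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
connector = VisionLanguageConnector()

image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
pixel_values = processor(images=image, return_tensors="pt")["pixel_values"]
pooled = vision_model(pixel_values=pixel_values).pooler_output   # [1, 768]
image_embeds = connector(pooled)                                  # [1, 512]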
Experiments/eval.ipynb
ADDED
@@ -0,0 +1,782 @@
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 4,
|
6 |
+
"id": "215cfd2f-62b0-4a86-a407-777a1d32597f",
|
7 |
+
"metadata": {},
|
8 |
+
"outputs": [
|
9 |
+
{
|
10 |
+
"name": "stdout",
|
11 |
+
"output_type": "stream",
|
12 |
+
"text": [
|
13 |
+
"[2024-01-24 15:18:49,948] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n"
|
14 |
+
]
|
15 |
+
}
|
16 |
+
],
|
17 |
+
"source": [
|
18 |
+
"from PIL import Image\n",
|
19 |
+
"import requests\n",
|
20 |
+
"\n",
|
21 |
+
"import torch\n",
|
22 |
+
"from torch import nn\n",
|
23 |
+
"from transformers import AutoProcessor, CLIPVisionModel, CLIPVisionConfig, CLIPPreTrainedModel\n",
|
24 |
+
"from transformers.models.clip.modeling_clip import CLIPVisionModelOutput, CLIPVisionTransformer\n",
|
25 |
+
"from transformers import WhisperProcessor, WhisperForConditionalGeneration\n",
|
26 |
+
"from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoTokenizer"
|
27 |
+
]
|
28 |
+
},
|
29 |
+
{
|
30 |
+
"cell_type": "code",
|
31 |
+
"execution_count": 5,
|
32 |
+
"id": "2244e8f3-fcc7-4309-9d4d-fea557f89f79",
|
33 |
+
"metadata": {},
|
34 |
+
"outputs": [],
|
35 |
+
"source": [
|
36 |
+
"from llava_phi import LlavaPhiForCausalLM"
|
37 |
+
]
|
38 |
+
},
|
39 |
+
{
|
40 |
+
"cell_type": "code",
|
41 |
+
"execution_count": 3,
|
42 |
+
"id": "587883e1-3419-4b14-b16b-38fabbc8bfaa",
|
43 |
+
"metadata": {},
|
44 |
+
"outputs": [],
|
45 |
+
"source": [
|
46 |
+
"# model = LlavaPhiForCausalLM.from_pretrained(\"./llava-phi/checkpoints/llavaPhi-v0-3b-finetune/checkpoint-4000\")"
|
47 |
+
]
|
48 |
+
},
|
49 |
+
{
|
50 |
+
"cell_type": "code",
|
51 |
+
"execution_count": 4,
|
52 |
+
"id": "0e27a7db-e2ab-4d65-b21d-497222e318ad",
|
53 |
+
"metadata": {},
|
54 |
+
"outputs": [],
|
55 |
+
"source": [
|
56 |
+
"# processor = AutoProcessor.from_pretrained(\"./llava-phi/checkpoints/llavaPhi-v0-3b-finetune/checkpoint-4000\")"
|
57 |
+
]
|
58 |
+
},
|
59 |
+
{
|
60 |
+
"cell_type": "code",
|
61 |
+
"execution_count": 5,
|
62 |
+
"id": "663efdd8-ea21-4231-a2ae-bcc0fb47b46a",
|
63 |
+
"metadata": {},
|
64 |
+
"outputs": [],
|
65 |
+
"source": [
|
66 |
+
"# prompt = \"<image>\\nUSER: What's the content of the image?\\nASSISTANT:\"\n",
|
67 |
+
"# url = \"https://www.ilankelman.org/stopsigns/australia.jpg\"\n",
|
68 |
+
"# image = Image.open(requests.get(url, stream=True).raw)"
|
69 |
+
]
|
70 |
+
},
|
71 |
+
{
|
72 |
+
"cell_type": "code",
|
73 |
+
"execution_count": 6,
|
74 |
+
"id": "f622609f-f6a7-4ec1-ac35-c1d33d9436ca",
|
75 |
+
"metadata": {},
|
76 |
+
"outputs": [],
|
77 |
+
"source": [
|
78 |
+
"# # Generate\n",
|
79 |
+
"# generate_ids = model.generate(**inputs, max_length=30)\n",
|
80 |
+
"# processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]"
|
81 |
+
]
|
82 |
+
},
|
83 |
+
{
|
84 |
+
"cell_type": "code",
|
85 |
+
"execution_count": 6,
|
86 |
+
"id": "45f5ba72-2e41-4ccc-84c1-97d542ebee63",
|
87 |
+
"metadata": {},
|
88 |
+
"outputs": [],
|
89 |
+
"source": [
|
90 |
+
"from llava_phi.model.builder import load_pretrained_model\n",
|
91 |
+
"from llava_phi.mm_utils import tokenizer_image_token, get_model_name_from_path\n",
|
92 |
+
"from llava_phi.utils import disable_torch_init\n",
|
93 |
+
"from llava_phi.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN\n",
|
94 |
+
"from llava_phi.conversation import conv_templates, SeparatorStyle"
|
95 |
+
]
|
96 |
+
},
|
97 |
+
{
|
98 |
+
"cell_type": "code",
|
99 |
+
"execution_count": 11,
|
100 |
+
"id": "b98ac5d3-5503-4430-81d1-19a4f8d6bd75",
|
101 |
+
"metadata": {},
|
102 |
+
"outputs": [],
|
103 |
+
"source": [
|
104 |
+
"model_path = \"checkpoints/llavaPhi-v0-3b-finetune/checkpoint-4000\"\n",
|
105 |
+
"model_name = get_model_name_from_path(model_path)"
|
106 |
+
]
|
107 |
+
},
|
108 |
+
{
|
109 |
+
"cell_type": "code",
|
110 |
+
"execution_count": 12,
|
111 |
+
"id": "42fd5721-75a7-475b-bd30-5ee23aeaac64",
|
112 |
+
"metadata": {},
|
113 |
+
"outputs": [
|
114 |
+
{
|
115 |
+
"data": {
|
116 |
+
"text/plain": [
|
117 |
+
"'llavaPhi-v0-3b-finetune_checkpoint-4000'"
|
118 |
+
]
|
119 |
+
},
|
120 |
+
"execution_count": 12,
|
121 |
+
"metadata": {},
|
122 |
+
"output_type": "execute_result"
|
123 |
+
}
|
124 |
+
],
|
125 |
+
"source": [
|
126 |
+
"model_name"
|
127 |
+
]
|
128 |
+
},
|
129 |
+
{
|
130 |
+
"cell_type": "code",
|
131 |
+
"execution_count": 13,
|
132 |
+
"id": "8c2076b5-3bfc-48fd-917b-5dfd06fc532f",
|
133 |
+
"metadata": {},
|
134 |
+
"outputs": [
|
135 |
+
{
|
136 |
+
"name": "stderr",
|
137 |
+
"output_type": "stream",
|
138 |
+
"text": [
|
139 |
+
"The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.\n"
|
140 |
+
]
|
141 |
+
},
|
142 |
+
{
|
143 |
+
"name": "stdout",
|
144 |
+
"output_type": "stream",
|
145 |
+
"text": [
|
146 |
+
"load llaVA-Phi MLLM!!!\n"
|
147 |
+
]
|
148 |
+
},
|
149 |
+
{
|
150 |
+
"name": "stderr",
|
151 |
+
"output_type": "stream",
|
152 |
+
"text": [
|
153 |
+
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
|
154 |
+
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
|
155 |
+
]
|
156 |
+
},
|
157 |
+
{
|
158 |
+
"data": {
|
159 |
+
"application/vnd.jupyter.widget-view+json": {
|
160 |
+
"model_id": "20b86f2c01744081b537620c8780f12e",
|
161 |
+
"version_major": 2,
|
162 |
+
"version_minor": 0
|
163 |
+
},
|
164 |
+
"text/plain": [
|
165 |
+
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
|
166 |
+
]
|
167 |
+
},
|
168 |
+
"metadata": {},
|
169 |
+
"output_type": "display_data"
|
170 |
+
},
|
171 |
+
{
|
172 |
+
"name": "stdout",
|
173 |
+
"output_type": "stream",
|
174 |
+
"text": [
|
175 |
+
"{'device_map': 'cuda'}\n"
|
176 |
+
]
|
177 |
+
}
|
178 |
+
],
|
179 |
+
"source": [
|
180 |
+
"tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name)"
|
181 |
+
]
|
182 |
+
},
|
183 |
+
{
|
184 |
+
"cell_type": "code",
|
185 |
+
"execution_count": 14,
|
186 |
+
"id": "4e46221e-0907-453e-8126-76199828493e",
|
187 |
+
"metadata": {},
|
188 |
+
"outputs": [],
|
189 |
+
"source": [
|
190 |
+
"qs = \"What's the content of the image?\"\n",
|
191 |
+
"qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\\n' + qs"
|
192 |
+
]
|
193 |
+
},
|
194 |
+
{
|
195 |
+
"cell_type": "code",
|
196 |
+
"execution_count": 15,
|
197 |
+
"id": "07355444-0eb8-4d4d-ad50-48b91c969664",
|
198 |
+
"metadata": {},
|
199 |
+
"outputs": [],
|
200 |
+
"source": [
|
201 |
+
"conv = conv_templates[\"default\"].copy()\n",
|
202 |
+
"conv.append_message(conv.roles[0], qs)\n",
|
203 |
+
"conv.append_message(conv.roles[1], None)\n",
|
204 |
+
"prompt = conv.get_prompt()"
|
205 |
+
]
|
206 |
+
},
|
207 |
+
{
|
208 |
+
"cell_type": "code",
|
209 |
+
"execution_count": 16,
|
210 |
+
"id": "ccb5674f-aff8-456e-b61b-1d167864f1a6",
|
211 |
+
"metadata": {},
|
212 |
+
"outputs": [
|
213 |
+
{
|
214 |
+
"data": {
|
215 |
+
"text/plain": [
|
216 |
+
"\"A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: <im_start><image><im_end>\\nWhat's the content of the image? ASSISTANT:\""
|
217 |
+
]
|
218 |
+
},
|
219 |
+
"execution_count": 16,
|
220 |
+
"metadata": {},
|
221 |
+
"output_type": "execute_result"
|
222 |
+
}
|
223 |
+
],
|
224 |
+
"source": [
|
225 |
+
"prompt"
|
226 |
+
]
|
227 |
+
},
|
228 |
+
{
|
229 |
+
"cell_type": "code",
|
230 |
+
"execution_count": 17,
|
231 |
+
"id": "a89cc181-2214-4844-b966-164a41744e54",
|
232 |
+
"metadata": {},
|
233 |
+
"outputs": [],
|
234 |
+
"source": [
|
235 |
+
"url = \"https://www.ilankelman.org/stopsigns/australia.jpg\"\n",
|
236 |
+
"image = Image.open(requests.get(url, stream=True).raw)\n",
|
237 |
+
"image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].cuda()\n",
|
238 |
+
"\n",
|
239 |
+
"input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()\n",
|
240 |
+
"\n",
|
241 |
+
"stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2"
|
242 |
+
]
|
243 |
+
},
|
244 |
+
{
|
245 |
+
"cell_type": "code",
|
246 |
+
"execution_count": 25,
|
247 |
+
"id": "0d519851-64d4-4cf5-b2eb-19474f9aa260",
|
248 |
+
"metadata": {},
|
249 |
+
"outputs": [
|
250 |
+
{
|
251 |
+
"data": {
|
252 |
+
"text/plain": [
|
253 |
+
"torch.Size([1, 55])"
|
254 |
+
]
|
255 |
+
},
|
256 |
+
"execution_count": 25,
|
257 |
+
"metadata": {},
|
258 |
+
"output_type": "execute_result"
|
259 |
+
}
|
260 |
+
],
|
261 |
+
"source": [
|
262 |
+
"input_ids.shape"
|
263 |
+
]
|
264 |
+
},
|
265 |
+
{
|
266 |
+
"cell_type": "code",
|
267 |
+
"execution_count": 24,
|
268 |
+
"id": "1694ff36-f214-4ed3-b2f3-d3dbd0a1a25b",
|
269 |
+
"metadata": {},
|
270 |
+
"outputs": [
|
271 |
+
{
|
272 |
+
"name": "stderr",
|
273 |
+
"output_type": "stream",
|
274 |
+
"text": [
|
275 |
+
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
|
276 |
+
]
|
277 |
+
}
|
278 |
+
],
|
279 |
+
"source": [
|
280 |
+
"from datasets import load_dataset\n",
|
281 |
+
"audio_ds = load_dataset(\"hf-internal-testing/librispeech_asr_dummy\", \"clean\", split=\"validation\")\n",
|
282 |
+
"audio = audio_ds[0][\"audio\"]\n",
|
283 |
+
"\n",
|
284 |
+
"whisper_w_proj = WhisperWithProjection(projection_dim=512)\n",
|
285 |
+
"audio_embed = whisper_w_proj(audio)[\"input_ids\"]"
|
286 |
+
]
|
287 |
+
},
|
288 |
+
{
|
289 |
+
"cell_type": "code",
|
290 |
+
"execution_count": 28,
|
291 |
+
"id": "9c4a9fae-d6ed-4fc2-ba02-97df64cddd93",
|
292 |
+
"metadata": {},
|
293 |
+
"outputs": [
|
294 |
+
{
|
295 |
+
"data": {
|
296 |
+
"text/plain": [
|
297 |
+
"(torch.Size([1, 33]), device(type='cpu'))"
|
298 |
+
]
|
299 |
+
},
|
300 |
+
"execution_count": 28,
|
301 |
+
"metadata": {},
|
302 |
+
"output_type": "execute_result"
|
303 |
+
}
|
304 |
+
],
|
305 |
+
"source": [
|
306 |
+
"audio_embed.shape, audio_embed.device"
|
307 |
+
]
|
308 |
+
},
|
309 |
+
{
|
310 |
+
"cell_type": "code",
|
311 |
+
"execution_count": 29,
|
312 |
+
"id": "c3fffe29-98fb-4f4b-ac51-4bdda9e46752",
|
313 |
+
"metadata": {},
|
314 |
+
"outputs": [],
|
315 |
+
"source": [
|
316 |
+
"input_ids = torch.concat([input_ids, audio_embed.to(\"cuda:0\")], dim=1)"
|
317 |
+
]
|
318 |
+
},
|
319 |
+
{
|
320 |
+
"cell_type": "code",
|
321 |
+
"execution_count": 30,
|
322 |
+
"id": "5dee1ec8-2db2-4f65-99e8-d34bd2735c9c",
|
323 |
+
"metadata": {},
|
324 |
+
"outputs": [
|
325 |
+
{
|
326 |
+
"data": {
|
327 |
+
"text/plain": [
|
328 |
+
"torch.Size([1, 88])"
|
329 |
+
]
|
330 |
+
},
|
331 |
+
"execution_count": 30,
|
332 |
+
"metadata": {},
|
333 |
+
"output_type": "execute_result"
|
334 |
+
}
|
335 |
+
],
|
336 |
+
"source": [
|
337 |
+
"input_ids.shape"
|
338 |
+
]
|
339 |
+
},
|
340 |
+
{
|
341 |
+
"cell_type": "code",
|
342 |
+
"execution_count": 31,
|
343 |
+
"id": "96033b43-4f57-4f0c-bcf7-37b57ca02e47",
|
344 |
+
"metadata": {},
|
345 |
+
"outputs": [],
|
346 |
+
"source": [
|
347 |
+
"with torch.inference_mode():\n",
|
348 |
+
" output_ids = model.generate(\n",
|
349 |
+
" input_ids,\n",
|
350 |
+
" images=image_tensor,\n",
|
351 |
+
" do_sample=True,\n",
|
352 |
+
" temperature=0.2,\n",
|
353 |
+
" max_new_tokens=1024,\n",
|
354 |
+
" eos_token_id=tokenizer.eos_token_id, # End of sequence token\n",
|
355 |
+
" pad_token_id=tokenizer.eos_token_id, # Pad token\n",
|
356 |
+
" use_cache=True,\n",
|
357 |
+
" )"
|
358 |
+
]
|
359 |
+
},
|
360 |
+
{
|
361 |
+
"cell_type": "code",
|
362 |
+
"execution_count": 32,
|
363 |
+
"id": "741e8da5-0d18-4c11-b559-76054ce4ca3a",
|
364 |
+
"metadata": {},
|
365 |
+
"outputs": [
|
366 |
+
{
|
367 |
+
"name": "stdout",
|
368 |
+
"output_type": "stream",
|
369 |
+
"text": [
|
370 |
+
"is a Japanese character from the story of Jesus, who is a Chinese monk who is also known for his teachings. The story is based on the story of the story of Jesus Christ, and it is a representation of the story of Jesus and the story of Jesus Christ.\n"
|
371 |
+
]
|
372 |
+
}
|
373 |
+
],
|
374 |
+
"source": [
|
375 |
+
"input_token_len = input_ids.shape[1]\n",
|
376 |
+
"n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()\n",
|
377 |
+
"if n_diff_input_output > 0:\n",
|
378 |
+
" print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')\n",
|
379 |
+
"outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]\n",
|
380 |
+
"outputs = outputs.strip()\n",
|
381 |
+
"if outputs.endswith(stop_str):\n",
|
382 |
+
" outputs = outputs[:-len(stop_str)]\n",
|
383 |
+
"outputs = outputs.strip()\n",
|
384 |
+
"print(outputs)"
|
385 |
+
]
|
386 |
+
},
|
387 |
+
{
|
388 |
+
"cell_type": "code",
|
389 |
+
"execution_count": 20,
|
390 |
+
"id": "69d494d4-d768-4645-b4d6-5c455791b50d",
|
391 |
+
"metadata": {},
|
392 |
+
"outputs": [],
|
393 |
+
"source": [
|
394 |
+
"# image"
|
395 |
+
]
|
396 |
+
},
|
397 |
+
{
|
398 |
+
"cell_type": "code",
|
399 |
+
"execution_count": null,
|
400 |
+
"id": "8a340856-a13f-4b18-9911-126a4ba37816",
|
401 |
+
"metadata": {},
|
402 |
+
"outputs": [],
|
403 |
+
"source": []
|
404 |
+
},
|
405 |
+
{
|
406 |
+
"cell_type": "code",
|
407 |
+
"execution_count": null,
|
408 |
+
"id": "3c56fdea-c7a1-4e67-9832-e2ed077d8704",
|
409 |
+
"metadata": {},
|
410 |
+
"outputs": [],
|
411 |
+
"source": []
|
412 |
+
},
|
413 |
+
{
|
414 |
+
"cell_type": "code",
|
415 |
+
"execution_count": 52,
|
416 |
+
"id": "89e84d39-8ed8-45db-ae82-27c156ee6dd1",
|
417 |
+
"metadata": {},
|
418 |
+
"outputs": [],
|
419 |
+
"source": [
|
420 |
+
"class AudioLanguageConnector:\n",
|
421 |
+
" def __init__(self, projection_dim):\n",
|
422 |
+
" model_name = \"microsoft/phi-2\"\n",
|
423 |
+
" self.phi2_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n",
|
424 |
+
" self.phi2_tokenizer.pad_token = self.phi2_tokenizer.eos_token\n",
|
425 |
+
" self.phi2_tokenizer.max_length = projection_dim\n",
|
426 |
+
"\n",
|
427 |
+
" def __call__(self, text):\n",
|
428 |
+
" text = f\"<audio_start> {text} <audio_end>\"\n",
|
429 |
+
" tokens = self.phi2_tokenizer(text, return_tensors=\"pt\", return_attention_mask=False)\n",
|
430 |
+
" return tokens\n",
|
431 |
+
" \n",
|
432 |
+
"\n",
|
433 |
+
"class WhisperWithProjection:\n",
|
434 |
+
" def __init__(self, projection_dim, device):\n",
|
435 |
+
" self.device = device\n",
|
436 |
+
" self.processor = WhisperProcessor.from_pretrained(\"openai/whisper-tiny\", device_map=device)\n",
|
437 |
+
" self.model = WhisperForConditionalGeneration.from_pretrained(\"openai/whisper-tiny\", device_map=device)\n",
|
438 |
+
" self.model.config.forced_decoder_ids = None\n",
|
439 |
+
" self.audio_language_connector = AudioLanguageConnector(projection_dim)\n",
|
440 |
+
" \n",
|
441 |
+
" def __call__(self, audio):\n",
|
442 |
+
" input_features = self.processor(audio[\"array\"],\n",
|
443 |
+
" sampling_rate=audio[\"sampling_rate\"],\n",
|
444 |
+
" return_tensors=\"pt\").input_features\n",
|
445 |
+
" # generate token ids\n",
|
446 |
+
" predicted_ids = self.model.generate(input_features.to(self.device))\n",
|
447 |
+
" # decode token ids to text \n",
|
448 |
+
" transcription = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)\n",
|
449 |
+
"\n",
|
450 |
+
" audio_embeddings = self.audio_language_connector(transcription)\n",
|
451 |
+
" return audio_embeddings.to(self.device)"
|
452 |
+
]
|
453 |
+
},
|
454 |
+
{
|
455 |
+
"cell_type": "code",
|
456 |
+
"execution_count": 53,
|
457 |
+
"id": "75e24be0-b236-4047-83ef-5c344e262476",
|
458 |
+
"metadata": {},
|
459 |
+
"outputs": [],
|
460 |
+
"source": [
|
461 |
+
"class MultiModalPhi2:\n",
|
462 |
+
" def __init__(self, model_path=\"checkpoints/llavaPhi-v0-3b-finetune/checkpoint-4000\",\n",
|
463 |
+
" temperature=0.2,\n",
|
464 |
+
" max_new_tokens=1024,\n",
|
465 |
+
" device=\"cuda\"):\n",
|
466 |
+
" self.temperature = temperature\n",
|
467 |
+
" self.max_new_tokens = max_new_tokens\n",
|
468 |
+
" self.device = device\n",
|
469 |
+
" model_name = get_model_name_from_path(model_path)\n",
|
470 |
+
" self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(model_path, None, model_name, device_map=device)\n",
|
471 |
+
" self.whisper_w_proj = WhisperWithProjection(projection_dim=512, device=device)\n",
|
472 |
+
" \n",
|
473 |
+
" \n",
|
474 |
+
" def __call__(self, text, audio, image):\n",
|
475 |
+
" qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\\n' + text\n",
|
476 |
+
" conv = conv_templates[\"default\"].copy()\n",
|
477 |
+
" conv.append_message(conv.roles[0], qs)\n",
|
478 |
+
" conv.append_message(conv.roles[1], None)\n",
|
479 |
+
" prompt = conv.get_prompt()\n",
|
480 |
+
"\n",
|
481 |
+
" image_tensor = self.image_processor.preprocess(image, return_tensors='pt')['pixel_values'].cuda()\n",
|
482 |
+
" \n",
|
483 |
+
" input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()\n",
|
484 |
+
"\n",
|
485 |
+
" audio_embed = self.whisper_w_proj(audio)[\"input_ids\"]\n",
|
486 |
+
" \n",
|
487 |
+
" stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2\n",
|
488 |
+
"\n",
|
489 |
+
" input_ids = torch.concat([input_ids, audio_embed], dim=1)\n",
|
490 |
+
"\n",
|
491 |
+
" with torch.inference_mode():\n",
|
492 |
+
" output_ids = self.model.generate(\n",
|
493 |
+
" input_ids,\n",
|
494 |
+
" images=image_tensor,\n",
|
495 |
+
" do_sample=True,\n",
|
496 |
+
" temperature=self.temperature,\n",
|
497 |
+
" max_new_tokens=self.max_new_tokens,\n",
|
498 |
+
" eos_token_id=tokenizer.eos_token_id, # End of sequence token\n",
|
499 |
+
" pad_token_id=tokenizer.eos_token_id, # Pad token\n",
|
500 |
+
" use_cache=True,\n",
|
501 |
+
" )\n",
|
502 |
+
"\n",
|
503 |
+
" input_token_len = input_ids.shape[1]\n",
|
504 |
+
" n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()\n",
|
505 |
+
" if n_diff_input_output > 0:\n",
|
506 |
+
" print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')\n",
|
507 |
+
" outputs = self.tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]\n",
|
508 |
+
" outputs = outputs.strip()\n",
|
509 |
+
" if outputs.endswith(stop_str):\n",
|
510 |
+
" outputs = outputs[:-len(stop_str)]\n",
|
511 |
+
" outputs = outputs.strip()\n",
|
512 |
+
" return outputs"
|
513 |
+
]
|
514 |
+
},
|
515 |
+
{
|
516 |
+
"cell_type": "code",
|
517 |
+
"execution_count": 54,
|
518 |
+
"id": "4efdbad4-d88a-4477-a3a0-f5591cd0b172",
|
519 |
+
"metadata": {},
|
520 |
+
"outputs": [
|
521 |
+
{
|
522 |
+
"name": "stderr",
|
523 |
+
"output_type": "stream",
|
524 |
+
"text": [
|
525 |
+
"The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.\n",
|
526 |
+
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
|
527 |
+
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
|
528 |
+
]
|
529 |
+
},
|
530 |
+
{
|
531 |
+
"name": "stdout",
|
532 |
+
"output_type": "stream",
|
533 |
+
"text": [
|
534 |
+
"load llaVA-Phi MLLM!!!\n"
|
535 |
+
]
|
536 |
+
},
|
537 |
+
{
|
538 |
+
"data": {
|
539 |
+
"application/vnd.jupyter.widget-view+json": {
|
540 |
+
"model_id": "492c17cf54f34d4d9e4f288fc9e72e79",
|
541 |
+
"version_major": 2,
|
542 |
+
"version_minor": 0
|
543 |
+
},
|
544 |
+
"text/plain": [
|
545 |
+
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
|
546 |
+
]
|
547 |
+
},
|
548 |
+
"metadata": {},
|
549 |
+
"output_type": "display_data"
|
550 |
+
},
|
551 |
+
{
|
552 |
+
"name": "stdout",
|
553 |
+
"output_type": "stream",
|
554 |
+
"text": [
|
555 |
+
"{'device_map': 'cuda'}\n"
|
556 |
+
]
|
557 |
+
},
|
558 |
+
{
|
559 |
+
"name": "stderr",
|
560 |
+
"output_type": "stream",
|
561 |
+
"text": [
|
562 |
+
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
|
563 |
+
]
|
564 |
+
}
|
565 |
+
],
|
566 |
+
"source": [
|
567 |
+
"multimodal_phi2 = MultiModalPhi2()"
|
568 |
+
]
|
569 |
+
},
|
570 |
+
{
|
571 |
+
"cell_type": "code",
|
572 |
+
"execution_count": 57,
|
573 |
+
"id": "9a6de0b0-a231-4d50-88e8-e40c6f7216c3",
|
574 |
+
"metadata": {},
|
575 |
+
"outputs": [],
|
576 |
+
"source": [
|
577 |
+
"text = \"tell me about the audio\""
|
578 |
+
]
|
579 |
+
},
|
580 |
+
{
|
581 |
+
"cell_type": "code",
|
582 |
+
"execution_count": 58,
|
583 |
+
"id": "b4919948-6a75-4d19-ba95-9ba233a7d3d9",
|
584 |
+
"metadata": {},
|
585 |
+
"outputs": [
|
586 |
+
{
|
587 |
+
"data": {
|
588 |
+
"text/plain": [
|
589 |
+
"'is a popular Japanese drama series featuring a man in a red and white costume, who is dressed as Santa Claus, is walking down the street. The scene takes place in a busy city environment, with people walking and standing on the sidewalk, likely enjoying the festive atmosphere and the festive atmosphere.'"
|
590 |
+
]
|
591 |
+
},
|
592 |
+
"execution_count": 58,
|
593 |
+
"metadata": {},
|
594 |
+
"output_type": "execute_result"
|
595 |
+
}
|
596 |
+
],
|
597 |
+
"source": [
|
598 |
+
"multimodal_phi2(text, audio, image)"
|
599 |
+
]
|
600 |
+
},
|
601 |
+
{
|
602 |
+
"cell_type": "code",
|
603 |
+
"execution_count": null,
|
604 |
+
"id": "590f2d64-62ed-4e6f-b7c8-b0cf68aecaab",
|
605 |
+
"metadata": {},
|
606 |
+
"outputs": [],
|
607 |
+
"source": []
|
608 |
+
},
|
609 |
+
{
|
610 |
+
"cell_type": "code",
|
611 |
+
"execution_count": 64,
|
612 |
+
"id": "c921eb63-feb5-4fa9-993b-2faeb6dfe1db",
|
613 |
+
"metadata": {},
|
614 |
+
"outputs": [],
|
615 |
+
"source": [
|
616 |
+
"from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig, CLIPImageProcessor"
|
617 |
+
]
|
618 |
+
},
|
619 |
+
{
|
620 |
+
"cell_type": "code",
|
621 |
+
"execution_count": 65,
|
622 |
+
"id": "b470a2c4-806a-435d-9fc2-f17448dbe5fc",
|
623 |
+
"metadata": {},
|
624 |
+
"outputs": [],
|
625 |
+
"source": [
|
626 |
+
"from llava_phi.model import LlavaPhiConfig"
|
627 |
+
]
|
628 |
+
},
|
629 |
+
{
|
630 |
+
"cell_type": "code",
|
631 |
+
"execution_count": 66,
|
632 |
+
"id": "4f7bc91a-0a41-45e5-92a4-daa1e3eea0da",
|
633 |
+
"metadata": {},
|
634 |
+
"outputs": [
|
635 |
+
{
|
636 |
+
"name": "stderr",
|
637 |
+
"output_type": "stream",
|
638 |
+
"text": [
|
639 |
+
"The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.\n",
|
640 |
+
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
|
641 |
+
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
|
642 |
+
]
|
643 |
+
},
|
644 |
+
{
|
645 |
+
"data": {
|
646 |
+
"application/vnd.jupyter.widget-view+json": {
|
647 |
+
"model_id": "993bc3a38cb84de4a2e3a79a3448c4d6",
|
648 |
+
"version_major": 2,
|
649 |
+
"version_minor": 0
|
650 |
+
},
|
651 |
+
"text/plain": [
|
652 |
+
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
|
653 |
+
]
|
654 |
+
},
|
655 |
+
"metadata": {},
|
656 |
+
"output_type": "display_data"
|
657 |
+
}
|
658 |
+
],
|
659 |
+
"source": [
|
660 |
+
"device_map = \"cuda:0\"\n",
|
661 |
+
"load_8bit = False\n",
|
662 |
+
"load_4bit = False\n",
|
663 |
+
"kwargs = {\"device_map\": device_map}\n",
|
664 |
+
"if load_8bit:\n",
|
665 |
+
" kwargs['load_in_8bit'] = True\n",
|
666 |
+
"elif load_4bit:\n",
|
667 |
+
" kwargs['load_in_4bit'] = True\n",
|
668 |
+
" kwargs['quantization_config'] = BitsAndBytesConfig(\n",
|
669 |
+
" load_in_4bit=True,\n",
|
670 |
+
" bnb_4bit_compute_dtype=torch.float16,\n",
|
671 |
+
" bnb_4bit_use_double_quant=True,\n",
|
672 |
+
" bnb_4bit_quant_type='nf4'\n",
|
673 |
+
" )\n",
|
674 |
+
"config = LlavaPhiConfig.from_pretrained(model_path, trust_remote_code=True)\n",
|
675 |
+
"tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)\n",
|
676 |
+
"model = LlavaPhiForCausalLM.from_pretrained(\n",
|
677 |
+
" model_path, \n",
|
678 |
+
" config=config, \n",
|
679 |
+
" use_safetensors=True, \n",
|
680 |
+
" **kwargs).to(\"cuda\")\n",
|
681 |
+
"image_processor = CLIPImageProcessor.from_pretrained(model_path)\n",
|
682 |
+
"mm_use_im_start_end = getattr(model.config, \"mm_use_im_start_end\", False)\n",
|
683 |
+
"mm_use_im_patch_token = getattr(model.config, \"mm_use_im_patch_token\", True)\n",
|
684 |
+
"\n",
|
685 |
+
"# TODO: the tokenizer length of phi-2 is 50295, but the output class of lm_head is 51200\n",
|
686 |
+
"if mm_use_im_patch_token:\n",
|
687 |
+
" tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)\n",
|
688 |
+
"if mm_use_im_start_end:\n",
|
689 |
+
" tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)\n",
|
690 |
+
" \n",
|
691 |
+
"if hasattr(model.config, \"max_sequence_length\"):\n",
|
692 |
+
" context_len = model.config.max_sequence_length\n",
|
693 |
+
"else:\n",
|
694 |
+
" context_len = 2048"
|
695 |
+
]
|
696 |
+
},
|
697 |
+
{
|
698 |
+
"cell_type": "code",
|
699 |
+
"execution_count": 70,
|
700 |
+
"id": "99355837-a297-4a25-aeb3-1670af7e9251",
|
701 |
+
"metadata": {},
|
702 |
+
"outputs": [
|
703 |
+
{
|
704 |
+
"ename": "KeyboardInterrupt",
|
705 |
+
"evalue": "",
|
706 |
+
"output_type": "error",
|
707 |
+
"traceback": [
|
708 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
709 |
+
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
|
710 |
+
"Cell \u001b[0;32mIn[70], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msave_pretrained\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mLlava-Phi-Checkpoint\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n",
|
711 |
+
"File \u001b[0;32m~/miniconda3/envs/torchenv/lib/python3.10/site-packages/transformers/modeling_utils.py:2376\u001b[0m, in \u001b[0;36mPreTrainedModel.save_pretrained\u001b[0;34m(self, save_directory, is_main_process, state_dict, save_function, push_to_hub, max_shard_size, safe_serialization, variant, token, save_peft_format, **kwargs)\u001b[0m\n\u001b[1;32m 2372\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m shard_file, shard \u001b[38;5;129;01min\u001b[39;00m shards\u001b[38;5;241m.\u001b[39mitems():\n\u001b[1;32m 2373\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m safe_serialization:\n\u001b[1;32m 2374\u001b[0m \u001b[38;5;66;03m# At some point we will need to deal better with save_function (used for TPU and other distributed\u001b[39;00m\n\u001b[1;32m 2375\u001b[0m \u001b[38;5;66;03m# joyfulness), but for now this enough.\u001b[39;00m\n\u001b[0;32m-> 2376\u001b[0m \u001b[43msafe_save_file\u001b[49m\u001b[43m(\u001b[49m\u001b[43mshard\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mos\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpath\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mjoin\u001b[49m\u001b[43m(\u001b[49m\u001b[43msave_directory\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mshard_file\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmetadata\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m{\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mformat\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mpt\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m}\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2377\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 2378\u001b[0m save_function(shard, os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mjoin(save_directory, shard_file))\n",
|
712 |
+
"File \u001b[0;32m~/miniconda3/envs/torchenv/lib/python3.10/site-packages/safetensors/torch.py:281\u001b[0m, in \u001b[0;36msave_file\u001b[0;34m(tensors, filename, metadata)\u001b[0m\n\u001b[1;32m 250\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21msave_file\u001b[39m(\n\u001b[1;32m 251\u001b[0m tensors: Dict[\u001b[38;5;28mstr\u001b[39m, torch\u001b[38;5;241m.\u001b[39mTensor],\n\u001b[1;32m 252\u001b[0m filename: Union[\u001b[38;5;28mstr\u001b[39m, os\u001b[38;5;241m.\u001b[39mPathLike],\n\u001b[1;32m 253\u001b[0m metadata: Optional[Dict[\u001b[38;5;28mstr\u001b[39m, \u001b[38;5;28mstr\u001b[39m]] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 254\u001b[0m ):\n\u001b[1;32m 255\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 256\u001b[0m \u001b[38;5;124;03m Saves a dictionary of tensors into raw bytes in safetensors format.\u001b[39;00m\n\u001b[1;32m 257\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 279\u001b[0m \u001b[38;5;124;03m ```\u001b[39;00m\n\u001b[1;32m 280\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 281\u001b[0m \u001b[43mserialize_file\u001b[49m\u001b[43m(\u001b[49m\u001b[43m_flatten\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfilename\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmetadata\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmetadata\u001b[49m\u001b[43m)\u001b[49m\n",
|
713 |
+
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
|
714 |
+
]
|
715 |
+
}
|
716 |
+
],
|
717 |
+
"source": [
|
718 |
+
"model.save_pretrained(\"Llava-Phi-Checkpoint\")"
|
719 |
+
]
|
720 |
+
},
|
721 |
+
{
|
722 |
+
"cell_type": "code",
|
723 |
+
"execution_count": null,
|
724 |
+
"id": "fa0bec34-a148-4340-a30c-6f09dd5e71ca",
|
725 |
+
"metadata": {},
|
726 |
+
"outputs": [],
|
727 |
+
"source": [
|
728 |
+
"model.push_to_hub(\"RaviNaik/Llava-Phi2\")"
|
729 |
+
]
|
730 |
+
},
|
731 |
+
{
|
732 |
+
"cell_type": "code",
|
733 |
+
"execution_count": 73,
|
734 |
+
"id": "382f74b0-2967-408a-badc-a90918810d74",
|
735 |
+
"metadata": {},
|
736 |
+
"outputs": [
|
737 |
+
{
|
738 |
+
"data": {
|
739 |
+
"text/plain": [
|
740 |
+
"CommitInfo(commit_url='https://huggingface.co/RaviNaik/Llava-Phi2/commit/fa8f7240058241243f6bdc3d6ab44bb691f76e39', commit_message='Upload tokenizer', commit_description='', oid='fa8f7240058241243f6bdc3d6ab44bb691f76e39', pr_url=None, pr_revision=None, pr_num=None)"
|
741 |
+
]
|
742 |
+
},
|
743 |
+
"execution_count": 73,
|
744 |
+
"metadata": {},
|
745 |
+
"output_type": "execute_result"
|
746 |
+
}
|
747 |
+
],
|
748 |
+
"source": [
|
749 |
+
"tokenizer.push_to_hub(\"RaviNaik/Llava-Phi2\")"
|
750 |
+
]
|
751 |
+
},
|
752 |
+
{
|
753 |
+
"cell_type": "code",
|
754 |
+
"execution_count": null,
|
755 |
+
"id": "b851459b-d3ac-4fb8-99b6-17a648adc41f",
|
756 |
+
"metadata": {},
|
757 |
+
"outputs": [],
|
758 |
+
"source": []
|
759 |
+
}
|
760 |
+
],
|
761 |
+
"metadata": {
|
762 |
+
"kernelspec": {
|
763 |
+
"display_name": "Python 3 (ipykernel)",
|
764 |
+
"language": "python",
|
765 |
+
"name": "python3"
|
766 |
+
},
|
767 |
+
"language_info": {
|
768 |
+
"codemirror_mode": {
|
769 |
+
"name": "ipython",
|
770 |
+
"version": 3
|
771 |
+
},
|
772 |
+
"file_extension": ".py",
|
773 |
+
"mimetype": "text/x-python",
|
774 |
+
"name": "python",
|
775 |
+
"nbconvert_exporter": "python",
|
776 |
+
"pygments_lexer": "ipython3",
|
777 |
+
"version": "3.10.12"
|
778 |
+
}
|
779 |
+
},
|
780 |
+
"nbformat": 4,
|
781 |
+
"nbformat_minor": 5
|
782 |
+
}
|
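For reference, the eval.ipynb cells above first wire the pieces together by hand (an <im_start><image><im_end> prompt, the preprocessed image tensor, and Whisper-transcribed audio ids concatenated onto the prompt ids) and then wrap the same flow in `MultiModalPhi2`. A short usage sketch of that wrapper, assuming the class defined in the notebook is in scope and the same checkpoint and sample inputs are available; the question string is illustrative:

import requests
from PIL import Image
from datasets import load_dataset

# Same sample inputs as in the notebook: a stop-sign photo and a dummy LibriSpeech clip.
url = "https://www.ilankelman.org/stopsigns/australia.jpg"
image = Image.open(requests.get(url, stream=True).raw)
audio = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")[0]["audio"]

# MultiModalPhi2 loads the fine-tuned llavaPhi checkpoint plus Whisper-tiny internally.
multimodal_phi2 = MultiModalPhi2(model_path="checkpoints/llavaPhi-v0-3b-finetune/checkpoint-4000")
answer = multimodal_phi2("What's the content of the image?", audio, image)
print(answer)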
Experiments/instruct_150k_data.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
Experiments/instruct_data.py
ADDED
@@ -0,0 +1,39 @@
1 |
+
from datasets import Dataset, IterableDataset
|
2 |
+
from PIL import Image
|
3 |
+
|
4 |
+
# ChatML format
|
5 |
+
templates = {
|
6 |
+
"assistant": "<|im_start|>assistant\n{msg}<|im_end|>", # message by assistant
|
7 |
+
"user": "<|im_start|>user\n{msg}<|im_end|>" # message by user
|
8 |
+
}
|
9 |
+
|
10 |
+
ds = Dataset.from_json("llava_instruct_150k.json", split="train")
|
11 |
+
ds_stream = ds.to_iterable_dataset()
|
12 |
+
|
13 |
+
|
14 |
+
def get_image(image_path):
|
15 |
+
image_path = f"train2014/COCO_train2014_{image_path}"
|
16 |
+
img = Image.open(image_path)
|
17 |
+
return img
|
18 |
+
|
19 |
+
def get_chatml_text(conversations):
|
20 |
+
chatml_text = ""
|
21 |
+
for conversation in conversations:
|
22 |
+
role = conversation["from"]
|
23 |
+
role = "user" if role == "human" else "assistant"
|
24 |
+
content = conversation["value"]
|
25 |
+
|
26 |
+
formatted_text = templates[role].format(msg=content)
|
27 |
+
chatml_text += formatted_text + "\n"
|
28 |
+
return chatml_text
|
29 |
+
|
30 |
+
def instruct_data_generator():
|
31 |
+
for sample in ds_stream:
|
32 |
+
image_path = sample["image"]
|
33 |
+
conversations = sample["conversations"]
|
34 |
+
|
35 |
+
image = get_image(image_path)
|
36 |
+
text = get_chatml_text(conversations)
|
37 |
+
yield {"text": text, "image": image}
|
38 |
+
|
39 |
+
instruct_ds = IterableDataset.from_generator(generator=instruct_data_generator)
|
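instruct_data.py streams llava_instruct_150k conversations as ChatML-formatted text paired with the corresponding COCO train2014 image. A quick check of what `get_chatml_text` produces, using an illustrative conversation (not taken from the dataset); iterating `instruct_ds` additionally assumes llava_instruct_150k.json and the train2014 images are present locally:

# Illustrative conversation in the llava_instruct_150k format ("human" maps to user, anything else to assistant).
sample_conversations = [
    {"from": "human", "value": "<image>\nWhat is shown in the picture?"},
    {"from": "gpt", "value": "A stop sign at a street corner."},
]
print(get_chatml_text(sample_conversations))
# <|im_start|>user
# <image>
# What is shown in the picture?<|im_end|>
# <|im_start|>assistant
# A stop sign at a street corner.<|im_end|>

# Pull one {"text", "image"} sample from the generator-backed IterableDataset.
first = next(iter(instruct_ds))
print(first["text"][:120])
print(first["image"].size)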
Experiments/llava_exp.ipynb
ADDED
@@ -0,0 +1,145 @@
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"id": "99576983-f881-47c8-8b5e-c6f561a93e71",
|
7 |
+
"metadata": {},
|
8 |
+
"outputs": [],
|
9 |
+
"source": [
|
10 |
+
"import transformers"
|
11 |
+
]
|
12 |
+
},
|
13 |
+
{
|
14 |
+
"cell_type": "code",
|
15 |
+
"execution_count": 2,
|
16 |
+
"id": "58ba19f2-4b91-4f90-a33d-4c1ed17e202a",
|
17 |
+
"metadata": {},
|
18 |
+
"outputs": [],
|
19 |
+
"source": [
|
20 |
+
"from transformers import LlavaForConditionalGeneration, LlavaConfig, CLIPVisionConfig, PhiConfig\n",
|
21 |
+
"\n",
|
22 |
+
"# Initializing a CLIP-vision config\n",
|
23 |
+
"vision_config = CLIPVisionConfig()\n",
|
24 |
+
"\n",
|
25 |
+
"# Initializing a Llama config\n",
|
26 |
+
"text_config = PhiConfig()\n",
|
27 |
+
"\n",
|
28 |
+
"# Initializing a Llava llava-1.5-7b style configuration\n",
|
29 |
+
"configuration = LlavaConfig(vision_config, text_config)\n",
|
30 |
+
"\n",
|
31 |
+
"# Initializing a model from the llava-1.5-7b style configuration\n",
|
32 |
+
"model = LlavaForConditionalGeneration(configuration)\n",
|
33 |
+
"\n",
|
34 |
+
"# Accessing the model configuration\n",
|
35 |
+
"configuration = model.config"
|
36 |
+
]
|
37 |
+
},
|
38 |
+
{
|
39 |
+
"cell_type": "code",
|
40 |
+
"execution_count": 5,
|
41 |
+
"id": "a806a07a-fe72-45a3-8ceb-8e942c6c845d",
|
42 |
+
"metadata": {},
|
43 |
+
"outputs": [
|
44 |
+
{
|
45 |
+
"data": {
|
46 |
+
"text/plain": [
|
47 |
+
"LlavaConfig {\n",
|
48 |
+
" \"ignore_index\": -100,\n",
|
49 |
+
" \"image_token_index\": 32000,\n",
|
50 |
+
" \"model_type\": \"llava\",\n",
|
51 |
+
" \"projector_hidden_act\": \"gelu\",\n",
|
52 |
+
" \"text_config\": {\n",
|
53 |
+
" \"embd_pdrop\": 0.0,\n",
|
54 |
+
" \"hidden_act\": \"gelu_new\",\n",
|
55 |
+
" \"hidden_size\": 2048,\n",
|
56 |
+
" \"intermediate_size\": 8192,\n",
|
57 |
+
" \"layer_norm_eps\": 1e-05,\n",
|
58 |
+
" \"model_type\": \"phi\",\n",
|
59 |
+
" \"num_hidden_layers\": 24,\n",
|
60 |
+
" \"partial_rotary_factor\": 0.5,\n",
|
61 |
+
" \"qk_layernorm\": false,\n",
|
62 |
+
" \"resid_pdrop\": 0.0,\n",
|
63 |
+
" \"vocab_size\": 51200\n",
|
64 |
+
" },\n",
|
65 |
+
" \"transformers_version\": \"4.36.2\",\n",
|
66 |
+
" \"vision_config\": {\n",
|
67 |
+
" \"hidden_size\": 768,\n",
|
68 |
+
" \"image_size\": 224,\n",
|
69 |
+
" \"intermediate_size\": 3072,\n",
|
70 |
+
" \"model_type\": \"clip_vision_model\",\n",
|
71 |
+
" \"num_attention_heads\": 12,\n",
|
72 |
+
" \"num_hidden_layers\": 12,\n",
|
73 |
+
" \"patch_size\": 32,\n",
|
74 |
+
" \"projection_dim\": 512\n",
|
75 |
+
" },\n",
|
76 |
+
" \"vision_feature_layer\": -2,\n",
|
77 |
+
" \"vision_feature_select_strategy\": \"default\",\n",
|
78 |
+
" \"vocab_size\": 32000\n",
|
79 |
+
"}"
|
80 |
+
]
|
81 |
+
},
|
82 |
+
"execution_count": 5,
|
83 |
+
"metadata": {},
|
84 |
+
"output_type": "execute_result"
|
85 |
+
}
|
86 |
+
],
|
87 |
+
"source": [
|
88 |
+
"model.config"
|
89 |
+
]
|
90 |
+
},
|
91 |
+
{
|
92 |
+
"cell_type": "code",
|
93 |
+
"execution_count": 6,
|
94 |
+
"id": "79efbc6b-f005-4a5c-82a1-112fa37f1904",
|
95 |
+
"metadata": {},
|
96 |
+
"outputs": [
|
97 |
+
{
|
98 |
+
"name": "stdout",
|
99 |
+
"output_type": "stream",
|
100 |
+
"text": [
|
101 |
+
"Cloning into 'llava-phi'...\n",
|
102 |
+
"remote: Enumerating objects: 151, done.\u001b[K\n",
|
103 |
+
"remote: Counting objects: 100% (151/151), done.\u001b[K\n",
|
104 |
+
"remote: Compressing objects: 100% (116/116), done.\u001b[K\n",
|
105 |
+
"remote: Total 151 (delta 36), reused 133 (delta 25), pack-reused 0\u001b[K\n",
|
106 |
+
"Receiving objects: 100% (151/151), 333.89 KiB | 112.00 KiB/s, done.\n",
|
107 |
+
"Resolving deltas: 100% (36/36), done.\n"
|
108 |
+
]
|
109 |
+
}
|
110 |
+
],
|
111 |
+
"source": [
|
112 |
+
"!git clone https://github.com/zhuyiche/llava-phi.git"
|
113 |
+
]
|
114 |
+
},
|
115 |
+
{
|
116 |
+
"cell_type": "code",
|
117 |
+
"execution_count": null,
|
118 |
+
"id": "cf827184-f334-4d86-ace1-fe9c92f84d66",
|
119 |
+
"metadata": {},
|
120 |
+
"outputs": [],
|
121 |
+
"source": []
|
122 |
+
}
|
123 |
+
],
|
124 |
+
"metadata": {
|
125 |
+
"kernelspec": {
|
126 |
+
"display_name": "Python 3 (ipykernel)",
|
127 |
+
"language": "python",
|
128 |
+
"name": "python3"
|
129 |
+
},
|
130 |
+
"language_info": {
|
131 |
+
"codemirror_mode": {
|
132 |
+
"name": "ipython",
|
133 |
+
"version": 3
|
134 |
+
},
|
135 |
+
"file_extension": ".py",
|
136 |
+
"mimetype": "text/x-python",
|
137 |
+
"name": "python",
|
138 |
+
"nbconvert_exporter": "python",
|
139 |
+
"pygments_lexer": "ipython3",
|
140 |
+
"version": "3.10.12"
|
141 |
+
}
|
142 |
+
},
|
143 |
+
"nbformat": 4,
|
144 |
+
"nbformat_minor": 5
|
145 |
+
}
|
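llava_exp.ipynb composes a LlavaConfig from a default CLIP ViT-B/32 vision config and a default Phi text config (2048 hidden size, 24 layers, 51200-token vocab, as printed above). A small sketch of inspecting that composition before training; the parameter count is an added illustration, not an output from the notebook:

from transformers import CLIPVisionConfig, PhiConfig, LlavaConfig, LlavaForConditionalGeneration

configuration = LlavaConfig(vision_config=CLIPVisionConfig(), text_config=PhiConfig())
model = LlavaForConditionalGeneration(configuration)

# Rough parameter budget of the composed model (CLIP tower + multimodal projector + Phi LM).
total_params = sum(p.numel() for p in model.parameters())
print(f"total parameters: {total_params / 1e9:.2f}B")
print(model.config.vision_config.hidden_size, model.config.text_config.hidden_size)  # 768 2048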
Experiments/multimodal_exp.ipynb
ADDED
@@ -0,0 +1,362 @@
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 23,
|
6 |
+
"id": "d4bed9ef-4bff-4d61-a4f9-a585f377f136",
|
7 |
+
"metadata": {},
|
8 |
+
"outputs": [],
|
9 |
+
"source": [
|
10 |
+
"from PIL import Image\n",
|
11 |
+
"import requests\n",
|
12 |
+
"\n",
|
13 |
+
"import torch\n",
|
14 |
+
"from torch import nn\n",
|
15 |
+
"from transformers import AutoProcessor, CLIPVisionModel, CLIPVisionConfig, CLIPPreTrainedModel\n",
|
16 |
+
"from transformers.models.clip.modeling_clip import CLIPVisionModelOutput, CLIPVisionTransformer\n",
|
17 |
+
"from transformers import WhisperProcessor, WhisperForConditionalGeneration\n",
|
18 |
+
"from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoTokenizer\n",
|
19 |
+
"from typing import Optional, Union, Tuple"
|
20 |
+
]
|
21 |
+
},
|
22 |
+
{
|
23 |
+
"cell_type": "code",
|
24 |
+
"execution_count": 43,
|
25 |
+
"id": "952314f0-ee9d-45e7-85b8-1e3e44c1a2fd",
|
26 |
+
"metadata": {},
|
27 |
+
"outputs": [],
|
28 |
+
"source": [
|
29 |
+
"class VisionLanguageConnector(nn.Module):\n",
|
30 |
+
" def __init__(self, hidden_size, projection_dim):\n",
|
31 |
+
" super().__init__()\n",
|
32 |
+
" self.mlp = nn.Sequential(\n",
|
33 |
+
" nn.Linear(hidden_size, hidden_size, bias=False),\n",
|
34 |
+
" nn.GELU(),\n",
|
35 |
+
" nn.Linear(hidden_size, projection_dim, bias=False)\n",
|
36 |
+
" )\n",
|
37 |
+
"\n",
|
38 |
+
" def forward(self, x):\n",
|
39 |
+
" return self.mlp(x)\n",
|
40 |
+
" \n",
|
41 |
+
"class ClipWithProjection():\n",
|
42 |
+
" config_class = CLIPVisionConfig\n",
|
43 |
+
" main_input_name = \"pixel_values\"\n",
|
44 |
+
"\n",
|
45 |
+
" def __init__(self, hidden_size, projection_dim):\n",
|
46 |
+
" super().__init__()\n",
|
47 |
+
" \n",
|
48 |
+
" self.processor = AutoProcessor.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
|
49 |
+
" self.vision_model = CLIPVisionModel.from_pretrained(\"openai/clip-vit-base-patch32\")\n",
|
50 |
+
" self.vision_language_connector = VisionLanguageConnector(hidden_size, projection_dim)\n",
|
51 |
+
"\n",
|
52 |
+
" def forward(\n",
|
53 |
+
" self,\n",
|
54 |
+
" image = None,\n",
|
55 |
+
" output_attentions: Optional[bool] = None,\n",
|
56 |
+
" output_hidden_states: Optional[bool] = None,\n",
|
57 |
+
" return_dict: Optional[bool] = None,\n",
|
58 |
+
" ) -> Union[Tuple, CLIPVisionModelOutput]:\n",
|
59 |
+
" \n",
|
60 |
+
" pixel_values = self.processor(images=image, return_tensors=\"pt\")[\"pixel_values\"]\n",
|
61 |
+
" vision_outputs = self.vision_model(\n",
|
62 |
+
" pixel_values=pixel_values,\n",
|
63 |
+
" output_attentions=output_attentions,\n",
|
64 |
+
" output_hidden_states=output_hidden_states,\n",
|
65 |
+
" return_dict=return_dict,\n",
|
66 |
+
" )\n",
|
67 |
+
"\n",
|
68 |
+
" pooled_output = vision_outputs[1] # pooled_output\n",
|
69 |
+
"\n",
|
70 |
+
" image_embeds = self.vision_language_connector(pooled_output)\n",
|
71 |
+
"\n",
|
72 |
+
" return CLIPVisionModelOutput(\n",
|
73 |
+
" image_embeds=image_embeds,\n",
|
74 |
+
" last_hidden_state=vision_outputs.last_hidden_state,\n",
|
75 |
+
" hidden_states=vision_outputs.hidden_states,\n",
|
76 |
+
" attentions=vision_outputs.attentions,\n",
|
77 |
+
" )"
|
78 |
+
]
|
79 |
+
},
|
80 |
+
{
|
81 |
+
"cell_type": "code",
|
82 |
+
"execution_count": 44,
|
83 |
+
"id": "bd2889fe-be85-44a3-afe8-65b47f7a93c3",
|
84 |
+
"metadata": {},
|
85 |
+
"outputs": [],
|
86 |
+
"source": [
|
87 |
+
"url = \"http://images.cocodataset.org/val2017/000000039769.jpg\"\n",
|
88 |
+
"image = Image.open(requests.get(url, stream=True).raw)"
|
89 |
+
]
|
90 |
+
},
|
91 |
+
{
|
92 |
+
"cell_type": "code",
|
93 |
+
"execution_count": 46,
|
94 |
+
"id": "17c72699-fe98-4b96-b63c-5c8ab7c1a65f",
|
95 |
+
"metadata": {},
|
96 |
+
"outputs": [],
|
97 |
+
"source": [
|
98 |
+
"# model = ClipWithProjection(768, 512)\n",
|
99 |
+
"# model.forward(image)"
|
100 |
+
]
|
101 |
+
},
|
102 |
+
{
|
103 |
+
"cell_type": "code",
|
104 |
+
"execution_count": 47,
|
105 |
+
"id": "70806156-38a9-45a2-bf9f-e72047a0173f",
|
106 |
+
"metadata": {},
|
107 |
+
"outputs": [],
|
108 |
+
"source": [
|
109 |
+
"class AudioLanguageConnector:\n",
|
110 |
+
" def __init__(self, projection_dim):\n",
|
111 |
+
" model_name = \"microsoft/phi-2\"\n",
|
112 |
+
" self.phi2_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n",
|
113 |
+
" self.phi2_tokenizer.pad_token = self.phi2_tokenizer.eos_token\n",
|
114 |
+
" self.phi2_tokenizer.max_length = projection_dim\n",
|
115 |
+
"\n",
|
116 |
+
" def __call__(self, text):\n",
|
117 |
+
" text = f\"<audio_start> {text} <audio_end>\"\n",
|
118 |
+
" tokens = self.phi2_tokenizer(text, return_tensors=\"pt\", return_attention_mask=False)\n",
|
119 |
+
" return tokens\n",
|
120 |
+
" \n",
|
121 |
+
"\n",
|
122 |
+
"class WhisperWithProjection:\n",
|
123 |
+
" def __init__(self, projection_dim):\n",
|
124 |
+
" self.processor = WhisperProcessor.from_pretrained(\"openai/whisper-tiny\")\n",
|
125 |
+
" self.model = WhisperForConditionalGeneration.from_pretrained(\"openai/whisper-tiny\")\n",
|
126 |
+
" self.model.config.forced_decoder_ids = None\n",
|
127 |
+
" self.audio_language_connector = AudioLanguageConnector(projection_dim)\n",
|
128 |
+
" \n",
|
129 |
+
" def forward(self, audio):\n",
|
130 |
+
" input_features = self.processor(audio[\"array\"],\n",
|
131 |
+
" sampling_rate=audio[\"sampling_rate\"],\n",
|
132 |
+
" return_tensors=\"pt\").input_features\n",
|
133 |
+
" # generate token ids\n",
|
134 |
+
" predicted_ids = self.model.generate(input_features)\n",
|
135 |
+
" # decode token ids to text \n",
|
136 |
+
" transcription = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)\n",
|
137 |
+
"\n",
|
138 |
+
" audio_embeddings = self.audio_language_connector(transcription)\n",
|
139 |
+
" return audio_embeddings"
|
140 |
+
]
|
141 |
+
},
|
142 |
+
{
|
143 |
+
"cell_type": "code",
|
144 |
+
"execution_count": 48,
|
145 |
+
"id": "79cc4d98-498b-4042-bd71-143b2477733d",
|
146 |
+
"metadata": {},
|
147 |
+
"outputs": [],
|
148 |
+
"source": [
|
149 |
+
"class TextModality:\n",
|
150 |
+
" def __init__(self, projection_dim):\n",
|
151 |
+
" model_name = \"microsoft/phi-2\"\n",
|
152 |
+
" self.phi2_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n",
|
153 |
+
" self.phi2_tokenizer.pad_token = self.phi2_tokenizer.eos_token\n",
|
154 |
+
" self.phi2_tokenizer.max_length = projection_dim\n",
|
155 |
+
"\n",
|
156 |
+
"\n",
|
157 |
+
" def __call__(self, text):\n",
|
158 |
+
" tokens = self.phi2_tokenizer(text, return_tensors=\"pt\", return_attention_mask=False)\n",
|
159 |
+
" return tokens"
|
160 |
+
]
|
161 |
+
},
|
162 |
+
{
|
163 |
+
"cell_type": "code",
|
164 |
+
"execution_count": 77,
|
165 |
+
"id": "ba4c4772-923f-48e8-a4af-b7d9c192dd4b",
|
166 |
+
"metadata": {},
|
167 |
+
"outputs": [],
|
168 |
+
"source": [
|
169 |
+
"class MultiModalPhi2:\n",
|
170 |
+
" def __init__(self):\n",
|
171 |
+
" self.text_modality = TextModality(projection_dim=768)\n",
|
172 |
+
" self.whisper_w_proj = WhisperWithProjection(projection_dim=512)\n",
|
173 |
+
" self.clip_w_proj = ClipWithProjection(hidden_size=768, projection_dim=768)\n",
|
174 |
+
" self.llm = self.load_llm()\n",
|
175 |
+
"\n",
|
176 |
+
" def load_llm(self):\n",
|
177 |
+
" model_name = \"microsoft/phi-2\"\n",
|
178 |
+
" \n",
|
179 |
+
" bnb_config = BitsAndBytesConfig(\n",
|
180 |
+
" load_in_4bit=True,\n",
|
181 |
+
" bnb_4bit_quant_type=\"nf4\",\n",
|
182 |
+
" bnb_4bit_compute_dtype=torch.float16)\n",
|
183 |
+
" \n",
|
184 |
+
" model = AutoModelForCausalLM.from_pretrained(\n",
|
185 |
+
" model_name,\n",
|
186 |
+
" quantization_config=bnb_config,\n",
|
187 |
+
" trust_remote_code=True,\n",
|
188 |
+
" device_map=\"cuda:0\"\n",
|
189 |
+
" )\n",
|
190 |
+
" model.config.use_cache = False\n",
|
191 |
+
" return model\n",
|
192 |
+
"\n",
|
193 |
+
" def forward(self, audio, image, text):\n",
|
194 |
+
" if text is not None:\n",
|
195 |
+
" text_embed = self.text_modality(text)[\"input_ids\"]\n",
|
196 |
+
" if audio is not None:\n",
|
197 |
+
" audio_embed = self.whisper_w_proj.forward(audio)[\"input_ids\"]\n",
|
198 |
+
" if image is not None:\n",
|
199 |
+
" image_embed = self.clip_w_proj.forward(image)[0]\n",
|
200 |
+
" print(text_embed.shape, text_embed.dtype)\n",
|
201 |
+
" print(audio_embed.shape, audio_embed.dtype)\n",
|
202 |
+
" print(image_embed.shape, image_embed.dtype)\n",
|
203 |
+
" \n",
|
204 |
+
" inputs = torch.concat([text_embed, audio_embed, image_embed], dim=1)\n",
|
205 |
+
" print(inputs.shape, inputs.dtype)\n",
|
206 |
+
" outputs = self.llm(inputs)\n",
|
207 |
+
"\n",
|
208 |
+
" return outputs \n",
|
209 |
+
" \n",
|
210 |
+
"\n",
|
211 |
+
" def generate(self, audio, text):\n",
|
212 |
+
" text_embeddings = self.text_modality(text)\n",
|
213 |
+
" audio_embeddings = self.whisper_w_proj.forward(audio)\n",
|
214 |
+
" inputs = torch.concat([text_embed[\"input_ids\"], audio_embed[\"input_ids\"]], dim=1)\n",
|
215 |
+
" \n",
|
216 |
+
" outputs = self.llm.generate(inputs, max_length=200)\n",
|
217 |
+
" text = self.text_modality.phi2_tokenizer.batch_decode(outputs)[0]\n",
|
218 |
+
" print(text)"
|
219 |
+
]
|
220 |
+
},
|
221 |
+
{
|
222 |
+
"cell_type": "code",
|
223 |
+
"execution_count": 74,
|
224 |
+
"id": "7ca694eb-8009-4eb9-9a4c-eac406ab9584",
|
225 |
+
"metadata": {},
|
226 |
+
"outputs": [],
|
227 |
+
"source": [
|
228 |
+
"from datasets import load_dataset\n",
|
229 |
+
"audio_ds = load_dataset(\"hf-internal-testing/librispeech_asr_dummy\", \"clean\", split=\"validation\")\n",
|
230 |
+
"audio = audio_ds[0][\"audio\"]"
|
231 |
+
]
|
232 |
+
},
|
233 |
+
{
|
234 |
+
"cell_type": "code",
|
235 |
+
"execution_count": 58,
|
236 |
+
"id": "37be28c5-4cc3-4471-b394-032c7602accc",
|
237 |
+
"metadata": {},
|
238 |
+
"outputs": [],
|
239 |
+
"source": [
|
240 |
+
"text = \"explain about the audio\""
|
241 |
+
]
|
242 |
+
},
|
243 |
+
{
|
244 |
+
"cell_type": "code",
|
245 |
+
"execution_count": 59,
|
246 |
+
"id": "c0705114-1670-4937-bc3e-3660e5a5d2c5",
|
247 |
+
"metadata": {},
|
248 |
+
"outputs": [],
|
249 |
+
"source": [
|
250 |
+
"# image"
|
251 |
+
]
|
252 |
+
},
|
253 |
+
{
|
254 |
+
"cell_type": "code",
|
255 |
+
"execution_count": 78,
|
256 |
+
"id": "0d7e5b49-b4bd-477c-87b8-91ef70857677",
|
257 |
+
"metadata": {},
|
258 |
+
"outputs": [
|
259 |
+
{
|
260 |
+
"name": "stderr",
|
261 |
+
"output_type": "stream",
|
262 |
+
"text": [
|
263 |
+
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
|
264 |
+
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
|
265 |
+
]
|
266 |
+
},
|
267 |
+
{
|
268 |
+
"data": {
|
269 |
+
"application/vnd.jupyter.widget-view+json": {
|
270 |
+
"model_id": "733dc7b2208b4853a89aea49bff9a55c",
|
271 |
+
"version_major": 2,
|
272 |
+
"version_minor": 0
|
273 |
+
},
|
274 |
+
"text/plain": [
|
275 |
+
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
|
276 |
+
]
|
277 |
+
},
|
278 |
+
"metadata": {},
|
279 |
+
"output_type": "display_data"
|
280 |
+
}
|
281 |
+
],
|
282 |
+
"source": [
|
283 |
+
"model = MultiModalPhi2()"
|
284 |
+
]
|
285 |
+
},
|
286 |
+
{
|
287 |
+
"cell_type": "code",
|
288 |
+
"execution_count": 79,
|
289 |
+
"id": "0b6471c4-4553-47f3-b38f-46057dcf80f2",
|
290 |
+
"metadata": {},
|
291 |
+
"outputs": [
|
292 |
+
{
|
293 |
+
"name": "stdout",
|
294 |
+
"output_type": "stream",
|
295 |
+
"text": [
|
296 |
+
"torch.Size([1, 5]) torch.int64\n",
|
297 |
+
"torch.Size([1, 33]) torch.int64\n",
|
298 |
+
"torch.Size([1, 768]) torch.float32\n",
|
299 |
+
"torch.Size([1, 806]) torch.float32\n"
|
300 |
+
]
|
301 |
+
},
|
302 |
+
{
|
303 |
+
"ename": "RuntimeError",
|
304 |
+
"evalue": "Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.cuda.FloatTensor instead (while checking arguments for embedding)",
|
305 |
+
"output_type": "error",
|
306 |
+
"traceback": [
|
307 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
308 |
+
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
|
309 |
+
"Cell \u001b[0;32mIn[79], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mforward\u001b[49m\u001b[43m(\u001b[49m\u001b[43maudio\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mimage\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtext\u001b[49m\u001b[43m)\u001b[49m\n",
|
310 |
+
"Cell \u001b[0;32mIn[77], line 38\u001b[0m, in \u001b[0;36mMultiModalPhi2.forward\u001b[0;34m(self, audio, image, text)\u001b[0m\n\u001b[1;32m 36\u001b[0m inputs \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mconcat([text_embed, audio_embed, image_embed], dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m 37\u001b[0m \u001b[38;5;28mprint\u001b[39m(inputs\u001b[38;5;241m.\u001b[39mshape, inputs\u001b[38;5;241m.\u001b[39mdtype)\n\u001b[0;32m---> 38\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mllm\u001b[49m\u001b[43m(\u001b[49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 40\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m outputs\n",
|
311 |
+
"File \u001b[0;32m~/miniconda3/envs/torchenv/lib/python3.10/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
312 |
+
"File \u001b[0;32m~/miniconda3/envs/torchenv/lib/python3.10/site-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
313 |
+
"File \u001b[0;32m~/miniconda3/envs/torchenv/lib/python3.10/site-packages/accelerate/hooks.py:165\u001b[0m, in \u001b[0;36madd_hook_to_module.<locals>.new_forward\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 163\u001b[0m output \u001b[38;5;241m=\u001b[39m old_forward(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 164\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 165\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mold_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 166\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m module\u001b[38;5;241m.\u001b[39m_hf_hook\u001b[38;5;241m.\u001b[39mpost_forward(module, output)\n",
|
314 |
+
"File \u001b[0;32m~/.cache/huggingface/modules/transformers_modules/microsoft/phi-2/85d00b03fee509307549d823fdd095473ba5197c/modeling_phi.py:1049\u001b[0m, in \u001b[0;36mPhiForCausalLM.forward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 1046\u001b[0m return_dict \u001b[38;5;241m=\u001b[39m return_dict \u001b[38;5;28;01mif\u001b[39;00m return_dict \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39muse_return_dict\n\u001b[1;32m 1048\u001b[0m \u001b[38;5;66;03m# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)\u001b[39;00m\n\u001b[0;32m-> 1049\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1050\u001b[0m \u001b[43m \u001b[49m\u001b[43minput_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minput_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1051\u001b[0m \u001b[43m \u001b[49m\u001b[43mattention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1052\u001b[0m \u001b[43m \u001b[49m\u001b[43mposition_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mposition_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1053\u001b[0m \u001b[43m \u001b[49m\u001b[43mpast_key_values\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpast_key_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1054\u001b[0m \u001b[43m \u001b[49m\u001b[43minputs_embeds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs_embeds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1055\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_cache\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_cache\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1056\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_attentions\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_attentions\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1057\u001b[0m \u001b[43m \u001b[49m\u001b[43moutput_hidden_states\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moutput_hidden_states\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1058\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_dict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_dict\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1059\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1061\u001b[0m hidden_states \u001b[38;5;241m=\u001b[39m outputs[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 1062\u001b[0m logits \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlm_head(hidden_states)\n",
|
315 |
+
"File \u001b[0;32m~/miniconda3/envs/torchenv/lib/python3.10/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
316 |
+
"File \u001b[0;32m~/miniconda3/envs/torchenv/lib/python3.10/site-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
317 |
+
"File \u001b[0;32m~/miniconda3/envs/torchenv/lib/python3.10/site-packages/accelerate/hooks.py:165\u001b[0m, in \u001b[0;36madd_hook_to_module.<locals>.new_forward\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 163\u001b[0m output \u001b[38;5;241m=\u001b[39m old_forward(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 164\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 165\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mold_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 166\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m module\u001b[38;5;241m.\u001b[39m_hf_hook\u001b[38;5;241m.\u001b[39mpost_forward(module, output)\n",
|
318 |
+
"File \u001b[0;32m~/.cache/huggingface/modules/transformers_modules/microsoft/phi-2/85d00b03fee509307549d823fdd095473ba5197c/modeling_phi.py:893\u001b[0m, in \u001b[0;36mPhiModel.forward\u001b[0;34m(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 890\u001b[0m position_ids \u001b[38;5;241m=\u001b[39m position_ids\u001b[38;5;241m.\u001b[39munsqueeze(\u001b[38;5;241m0\u001b[39m)\n\u001b[1;32m 892\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m inputs_embeds \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 893\u001b[0m inputs_embeds \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43membed_tokens\u001b[49m\u001b[43m(\u001b[49m\u001b[43minput_ids\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 895\u001b[0m inputs_embeds \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39membed_dropout(inputs_embeds)\n\u001b[1;32m 897\u001b[0m \u001b[38;5;66;03m# Attention mask.\u001b[39;00m\n",
|
319 |
+
"File \u001b[0;32m~/miniconda3/envs/torchenv/lib/python3.10/site-packages/torch/nn/modules/module.py:1518\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1516\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[1;32m 1517\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1518\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
|
320 |
+
"File \u001b[0;32m~/miniconda3/envs/torchenv/lib/python3.10/site-packages/torch/nn/modules/module.py:1527\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1522\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1523\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1524\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[1;32m 1525\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1526\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1527\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1529\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 1530\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
|
321 |
+
"File \u001b[0;32m~/miniconda3/envs/torchenv/lib/python3.10/site-packages/accelerate/hooks.py:165\u001b[0m, in \u001b[0;36madd_hook_to_module.<locals>.new_forward\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 163\u001b[0m output \u001b[38;5;241m=\u001b[39m old_forward(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 164\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 165\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mold_forward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 166\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m module\u001b[38;5;241m.\u001b[39m_hf_hook\u001b[38;5;241m.\u001b[39mpost_forward(module, output)\n",
|
322 |
+
"File \u001b[0;32m~/miniconda3/envs/torchenv/lib/python3.10/site-packages/torch/nn/modules/sparse.py:162\u001b[0m, in \u001b[0;36mEmbedding.forward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 161\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Tensor) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tensor:\n\u001b[0;32m--> 162\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43membedding\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 163\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpadding_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmax_norm\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 164\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnorm_type\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mscale_grad_by_freq\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msparse\u001b[49m\u001b[43m)\u001b[49m\n",
|
323 |
+
"File \u001b[0;32m~/miniconda3/envs/torchenv/lib/python3.10/site-packages/torch/nn/functional.py:2233\u001b[0m, in \u001b[0;36membedding\u001b[0;34m(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)\u001b[0m\n\u001b[1;32m 2227\u001b[0m \u001b[38;5;66;03m# Note [embedding_renorm set_grad_enabled]\u001b[39;00m\n\u001b[1;32m 2228\u001b[0m \u001b[38;5;66;03m# XXX: equivalent to\u001b[39;00m\n\u001b[1;32m 2229\u001b[0m \u001b[38;5;66;03m# with torch.no_grad():\u001b[39;00m\n\u001b[1;32m 2230\u001b[0m \u001b[38;5;66;03m# torch.embedding_renorm_\u001b[39;00m\n\u001b[1;32m 2231\u001b[0m \u001b[38;5;66;03m# remove once script supports set_grad_enabled\u001b[39;00m\n\u001b[1;32m 2232\u001b[0m _no_grad_embedding_renorm_(weight, \u001b[38;5;28minput\u001b[39m, max_norm, norm_type)\n\u001b[0;32m-> 2233\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43membedding\u001b[49m\u001b[43m(\u001b[49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpadding_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mscale_grad_by_freq\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msparse\u001b[49m\u001b[43m)\u001b[49m\n",
|
324 |
+
"\u001b[0;31mRuntimeError\u001b[0m: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.cuda.FloatTensor instead (while checking arguments for embedding)"
|
325 |
+
]
|
326 |
+
}
|
327 |
+
],
|
328 |
+
"source": [
|
329 |
+
"model.forward(audio, image, text)"
|
330 |
+
]
|
331 |
+
},
|
332 |
+
{
|
333 |
+
"cell_type": "code",
|
334 |
+
"execution_count": null,
|
335 |
+
"id": "4ca96caf-82e2-4f07-87b3-8654dfdc89aa",
|
336 |
+
"metadata": {},
|
337 |
+
"outputs": [],
|
338 |
+
"source": []
|
339 |
+
}
|
340 |
+
],
|
341 |
+
"metadata": {
|
342 |
+
"kernelspec": {
|
343 |
+
"display_name": "Python 3 (ipykernel)",
|
344 |
+
"language": "python",
|
345 |
+
"name": "python3"
|
346 |
+
},
|
347 |
+
"language_info": {
|
348 |
+
"codemirror_mode": {
|
349 |
+
"name": "ipython",
|
350 |
+
"version": 3
|
351 |
+
},
|
352 |
+
"file_extension": ".py",
|
353 |
+
"mimetype": "text/x-python",
|
354 |
+
"name": "python",
|
355 |
+
"nbconvert_exporter": "python",
|
356 |
+
"pygments_lexer": "ipython3",
|
357 |
+
"version": "3.10.12"
|
358 |
+
}
|
359 |
+
},
|
360 |
+
"nbformat": 4,
|
361 |
+
"nbformat_minor": 5
|
362 |
+
}
|
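The `RuntimeError` at the end of this notebook comes from concatenating integer token ids (text, transcribed audio) with the float CLIP projection and handing the mix to phi-2 as `input_ids`, so the embedding lookup receives float values. A minimal sketch of one way around it, reusing `model`, `audio`, `image` and `text` from the cells above: push the token ids through the LLM's own embedding table and pass everything via `inputs_embeds`, which phi-2's `forward` accepts. The `nn.Linear` below is an untrained placeholder standing in for a learned `vision_language_connector`, and device/dtype handling is simplified.

```python
import torch
from torch import nn

# Rebuild the three per-modality tensors exactly as MultiModalPhi2.forward does.
llm = model.llm
text_ids = model.text_modality(text)["input_ids"]             # [1, 5], int64
audio_ids = model.whisper_w_proj.forward(audio)["input_ids"]  # [1, 33], int64
image_vec = model.clip_w_proj.forward(image)[0]               # [1, 768], float32

# Token ids go through phi-2's own embedding table, so nn.Embedding never sees floats.
embed_tokens = llm.get_input_embeddings()
llm_dim = embed_tokens.embedding_dim                          # 2560 for phi-2
text_emb = embed_tokens(text_ids.to(llm.device))
audio_emb = embed_tokens(audio_ids.to(llm.device))

# Placeholder projection (untrained) standing in for a vision_language_connector:
# one 768-d CLIP vector becomes a single pseudo-token of width llm_dim.
to_llm_width = nn.Linear(768, llm_dim).to(llm.device)
image_emb = to_llm_width(image_vec.to(llm.device)).unsqueeze(1)

inputs_embeds = torch.cat(
    [text_emb, audio_emb, image_emb.to(text_emb.dtype)], dim=1
)
outputs = llm(inputs_embeds=inputs_embeds)
```

Projecting image features to the LLM width and splicing them in as embeddings rather than ids is the usual LLaVA-style approach.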
Experiments/pretrain_data_check.ipynb
ADDED
@@ -0,0 +1,304 @@
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 5,
|
6 |
+
"id": "61c272f2-edbe-4b7d-8fec-3ab431400cd3",
|
7 |
+
"metadata": {},
|
8 |
+
"outputs": [],
|
9 |
+
"source": [
|
10 |
+
"import json"
|
11 |
+
]
|
12 |
+
},
|
13 |
+
{
|
14 |
+
"cell_type": "code",
|
15 |
+
"execution_count": 2,
|
16 |
+
"id": "e9dfd7d7-1685-4fc7-bbb9-3905c32d8ba1",
|
17 |
+
"metadata": {},
|
18 |
+
"outputs": [],
|
19 |
+
"source": [
|
20 |
+
"with open(\"metadata.json\", \"rb\") as f:\n",
|
21 |
+
" metadata = json.load(f)"
|
22 |
+
]
|
23 |
+
},
|
24 |
+
{
|
25 |
+
"cell_type": "code",
|
26 |
+
"execution_count": 4,
|
27 |
+
"id": "70bdba48-db01-42ac-8d89-edc69d7d7672",
|
28 |
+
"metadata": {},
|
29 |
+
"outputs": [
|
30 |
+
{
|
31 |
+
"data": {
|
32 |
+
"text/plain": [
|
33 |
+
"595375"
|
34 |
+
]
|
35 |
+
},
|
36 |
+
"execution_count": 4,
|
37 |
+
"metadata": {},
|
38 |
+
"output_type": "execute_result"
|
39 |
+
}
|
40 |
+
],
|
41 |
+
"source": [
|
42 |
+
"len(metadata)"
|
43 |
+
]
|
44 |
+
},
|
45 |
+
{
|
46 |
+
"cell_type": "code",
|
47 |
+
"execution_count": 14,
|
48 |
+
"id": "59e193cc-0dd8-4f7e-959a-fbad0133d76c",
|
49 |
+
"metadata": {},
|
50 |
+
"outputs": [],
|
51 |
+
"source": [
|
52 |
+
"with open(\"blip_laion_cc_sbu_558k.jsonblip_laion_cc_sbu_558k.json\", \"rb\") as f:\n",
|
53 |
+
" data = json.load(f)"
|
54 |
+
]
|
55 |
+
},
|
56 |
+
{
|
57 |
+
"cell_type": "code",
|
58 |
+
"execution_count": 7,
|
59 |
+
"id": "f3157f41-269b-4f7a-b3ba-9be711babe02",
|
60 |
+
"metadata": {},
|
61 |
+
"outputs": [
|
62 |
+
{
|
63 |
+
"data": {
|
64 |
+
"text/plain": [
|
65 |
+
"{'id': '004539375',\n",
|
66 |
+
" 'image': '00453/004539375.jpg',\n",
|
67 |
+
" 'conversations': [{'from': 'human',\n",
|
68 |
+
" 'value': 'Render a clear and concise summary of the photo.\\n<image>'},\n",
|
69 |
+
" {'from': 'gpt',\n",
|
70 |
+
" 'value': 'select luxury furniture 3 - inch gel memory foam mattress topper'}]}"
|
71 |
+
]
|
72 |
+
},
|
73 |
+
"execution_count": 7,
|
74 |
+
"metadata": {},
|
75 |
+
"output_type": "execute_result"
|
76 |
+
}
|
77 |
+
],
|
78 |
+
"source": [
|
79 |
+
"data[0]"
|
80 |
+
]
|
81 |
+
},
|
82 |
+
{
|
83 |
+
"cell_type": "code",
|
84 |
+
"execution_count": 8,
|
85 |
+
"id": "50d8a051-1526-47dd-ad71-d3c66f7bd34e",
|
86 |
+
"metadata": {},
|
87 |
+
"outputs": [
|
88 |
+
{
|
89 |
+
"data": {
|
90 |
+
"text/plain": [
|
91 |
+
"{'id': '004374662',\n",
|
92 |
+
" 'image': '00437/004374662.jpg',\n",
|
93 |
+
" 'conversations': [{'from': 'human',\n",
|
94 |
+
" 'value': 'Give a brief description of the image.\\n<image>'},\n",
|
95 |
+
" {'from': 'gpt', 'value': 'the north face duffel bag camo large'}]}"
|
96 |
+
]
|
97 |
+
},
|
98 |
+
"execution_count": 8,
|
99 |
+
"metadata": {},
|
100 |
+
"output_type": "execute_result"
|
101 |
+
}
|
102 |
+
],
|
103 |
+
"source": [
|
104 |
+
"data[234]"
|
105 |
+
]
|
106 |
+
},
|
107 |
+
{
|
108 |
+
"cell_type": "code",
|
109 |
+
"execution_count": 17,
|
110 |
+
"id": "2e6d5664-4583-49a6-93cc-079ee2d1ff6c",
|
111 |
+
"metadata": {},
|
112 |
+
"outputs": [
|
113 |
+
{
|
114 |
+
"data": {
|
115 |
+
"text/plain": [
|
116 |
+
"558128"
|
117 |
+
]
|
118 |
+
},
|
119 |
+
"execution_count": 17,
|
120 |
+
"metadata": {},
|
121 |
+
"output_type": "execute_result"
|
122 |
+
}
|
123 |
+
],
|
124 |
+
"source": [
|
125 |
+
"len(data)"
|
126 |
+
]
|
127 |
+
},
|
128 |
+
{
|
129 |
+
"cell_type": "code",
|
130 |
+
"execution_count": 10,
|
131 |
+
"id": "11ed106d-6bef-482c-a456-5eaaf2025534",
|
132 |
+
"metadata": {},
|
133 |
+
"outputs": [
|
134 |
+
{
|
135 |
+
"data": {
|
136 |
+
"text/plain": [
|
137 |
+
"{'id': 'GCC_train_001749371',\n",
|
138 |
+
" 'image': 'GCC_train_001749371.jpg',\n",
|
139 |
+
" 'caption': 'if you are dreaming of simpler or off - the - grid living , a yurt is a fantastic option',\n",
|
140 |
+
" 'blip_caption': 'a white and tan yurt sitting on a dirt road',\n",
|
141 |
+
" 'url': 'https://i.pinimg.com/736x/14/7b/64/147b64467ee966d9a578097bb70475ad--yurt-kits-small-space-living.jpg'}"
|
142 |
+
]
|
143 |
+
},
|
144 |
+
"execution_count": 10,
|
145 |
+
"metadata": {},
|
146 |
+
"output_type": "execute_result"
|
147 |
+
}
|
148 |
+
],
|
149 |
+
"source": [
|
150 |
+
"metadata[67]"
|
151 |
+
]
|
152 |
+
},
|
153 |
+
{
|
154 |
+
"cell_type": "code",
|
155 |
+
"execution_count": 15,
|
156 |
+
"id": "ce8adcec-2499-4be3-be1d-7313fe54e96a",
|
157 |
+
"metadata": {},
|
158 |
+
"outputs": [
|
159 |
+
{
|
160 |
+
"data": {
|
161 |
+
"text/plain": [
|
162 |
+
"{'id': '000466761',\n",
|
163 |
+
" 'image': '00046/000466761.jpg',\n",
|
164 |
+
" 'conversations': [{'from': 'human',\n",
|
165 |
+
" 'value': '<image>\\nProvide a brief description of the given image.'},\n",
|
166 |
+
" {'from': 'gpt',\n",
|
167 |
+
" 'value': 'a clipboard and a pen with the words public health emergency next to it on a white table'}]}"
|
168 |
+
]
|
169 |
+
},
|
170 |
+
"execution_count": 15,
|
171 |
+
"metadata": {},
|
172 |
+
"output_type": "execute_result"
|
173 |
+
}
|
174 |
+
],
|
175 |
+
"source": [
|
176 |
+
"data[67]"
|
177 |
+
]
|
178 |
+
},
|
179 |
+
{
|
180 |
+
"cell_type": "code",
|
181 |
+
"execution_count": 16,
|
182 |
+
"id": "068313b6-6379-4ca2-892c-682634d3581e",
|
183 |
+
"metadata": {},
|
184 |
+
"outputs": [
|
185 |
+
{
|
186 |
+
"data": {
|
187 |
+
"text/plain": [
|
188 |
+
"list"
|
189 |
+
]
|
190 |
+
},
|
191 |
+
"execution_count": 16,
|
192 |
+
"metadata": {},
|
193 |
+
"output_type": "execute_result"
|
194 |
+
}
|
195 |
+
],
|
196 |
+
"source": [
|
197 |
+
"type(data)"
|
198 |
+
]
|
199 |
+
},
|
200 |
+
{
|
201 |
+
"cell_type": "code",
|
202 |
+
"execution_count": 24,
|
203 |
+
"id": "9ec33b51-4a0b-4a1e-81f7-2fda7cddb25f",
|
204 |
+
"metadata": {},
|
205 |
+
"outputs": [],
|
206 |
+
"source": [
|
207 |
+
"sample_data = data[:200000]"
|
208 |
+
]
|
209 |
+
},
|
210 |
+
{
|
211 |
+
"cell_type": "code",
|
212 |
+
"execution_count": 25,
|
213 |
+
"id": "095685e5-40f1-4d84-8280-ef74fa56c5a2",
|
214 |
+
"metadata": {},
|
215 |
+
"outputs": [
|
216 |
+
{
|
217 |
+
"data": {
|
218 |
+
"text/plain": [
|
219 |
+
"200000"
|
220 |
+
]
|
221 |
+
},
|
222 |
+
"execution_count": 25,
|
223 |
+
"metadata": {},
|
224 |
+
"output_type": "execute_result"
|
225 |
+
}
|
226 |
+
],
|
227 |
+
"source": [
|
228 |
+
"len(sample_data)"
|
229 |
+
]
|
230 |
+
},
|
231 |
+
{
|
232 |
+
"cell_type": "code",
|
233 |
+
"execution_count": 26,
|
234 |
+
"id": "ffbad552-23fd-475f-8e9a-7118bcc4f51e",
|
235 |
+
"metadata": {},
|
236 |
+
"outputs": [],
|
237 |
+
"source": [
|
238 |
+
"with open(\"llava-phi/pretrain_data/blip_sample.json\", \"w\") as f:\n",
|
239 |
+
" json.dump(sample_data, f)"
|
240 |
+
]
|
241 |
+
},
|
242 |
+
{
|
243 |
+
"cell_type": "code",
|
244 |
+
"execution_count": 27,
|
245 |
+
"id": "69a05d25-6f3b-40c0-a3b5-e185ff526471",
|
246 |
+
"metadata": {},
|
247 |
+
"outputs": [],
|
248 |
+
"source": [
|
249 |
+
"with open(\"llava-phi/pretrain_data/blip_sample.json\", \"rb\") as f:\n",
|
250 |
+
" sample = json.load(f)"
|
251 |
+
]
|
252 |
+
},
|
253 |
+
{
|
254 |
+
"cell_type": "code",
|
255 |
+
"execution_count": 28,
|
256 |
+
"id": "200eea06-dfd6-4b3a-bb91-82af7d363951",
|
257 |
+
"metadata": {},
|
258 |
+
"outputs": [
|
259 |
+
{
|
260 |
+
"data": {
|
261 |
+
"text/plain": [
|
262 |
+
"200000"
|
263 |
+
]
|
264 |
+
},
|
265 |
+
"execution_count": 28,
|
266 |
+
"metadata": {},
|
267 |
+
"output_type": "execute_result"
|
268 |
+
}
|
269 |
+
],
|
270 |
+
"source": [
|
271 |
+
"len(sample)"
|
272 |
+
]
|
273 |
+
},
|
274 |
+
{
|
275 |
+
"cell_type": "code",
|
276 |
+
"execution_count": null,
|
277 |
+
"id": "f86caa1e-edea-4a9c-934f-5420ede80d0d",
|
278 |
+
"metadata": {},
|
279 |
+
"outputs": [],
|
280 |
+
"source": []
|
281 |
+
}
|
282 |
+
],
|
283 |
+
"metadata": {
|
284 |
+
"kernelspec": {
|
285 |
+
"display_name": "Python 3 (ipykernel)",
|
286 |
+
"language": "python",
|
287 |
+
"name": "python3"
|
288 |
+
},
|
289 |
+
"language_info": {
|
290 |
+
"codemirror_mode": {
|
291 |
+
"name": "ipython",
|
292 |
+
"version": 3
|
293 |
+
},
|
294 |
+
"file_extension": ".py",
|
295 |
+
"mimetype": "text/x-python",
|
296 |
+
"name": "python",
|
297 |
+
"nbconvert_exporter": "python",
|
298 |
+
"pygments_lexer": "ipython3",
|
299 |
+
"version": "3.10.12"
|
300 |
+
}
|
301 |
+
},
|
302 |
+
"nbformat": 4,
|
303 |
+
"nbformat_minor": 5
|
304 |
+
}
|
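The cells above build the 200k pretraining subset by taking the first 200,000 entries of the 558k BLIP-captioned list. If a subset that is not biased by file order is preferred, a seeded random sample can be written out the same way; a small sketch (seed is an arbitrary choice, filenames match the ones used above):

```python
import json
import random

with open("blip_laion_cc_sbu_558k.json", "rb") as f:
    data = json.load(f)

random.seed(42)
sample_data = random.sample(data, 200000)  # instead of data[:200000]

with open("llava-phi/pretrain_data/blip_sample.json", "w") as f:
    json.dump(sample_data, f)
```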
Experiments/whispher_exp.ipynb
ADDED
@@ -0,0 +1,500 @@
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 9,
|
6 |
+
"id": "bb4dd66b-0c17-48d4-9d34-f48cece2feb5",
|
7 |
+
"metadata": {},
|
8 |
+
"outputs": [],
|
9 |
+
"source": [
|
10 |
+
"# !pip install soundfile\n",
|
11 |
+
"# !pip install librosa"
|
12 |
+
]
|
13 |
+
},
|
14 |
+
{
|
15 |
+
"cell_type": "code",
|
16 |
+
"execution_count": 1,
|
17 |
+
"id": "6e9386ea-4862-4f5b-a02f-d656e1a5ab9e",
|
18 |
+
"metadata": {},
|
19 |
+
"outputs": [],
|
20 |
+
"source": [
|
21 |
+
"from transformers import WhisperProcessor, WhisperForConditionalGeneration\n",
|
22 |
+
"from datasets import load_dataset"
|
23 |
+
]
|
24 |
+
},
|
25 |
+
{
|
26 |
+
"cell_type": "code",
|
27 |
+
"execution_count": 2,
|
28 |
+
"id": "914ab2b4-389d-4c48-8d1d-1250356646ac",
|
29 |
+
"metadata": {},
|
30 |
+
"outputs": [],
|
31 |
+
"source": [
|
32 |
+
"# load model and processor\n",
|
33 |
+
"processor = WhisperProcessor.from_pretrained(\"openai/whisper-tiny\")\n",
|
34 |
+
"model = WhisperForConditionalGeneration.from_pretrained(\"openai/whisper-tiny\")\n",
|
35 |
+
"model.config.forced_decoder_ids = None\n",
|
36 |
+
"\n",
|
37 |
+
"# load dummy dataset and read audio files\n",
|
38 |
+
"ds = load_dataset(\"hf-internal-testing/librispeech_asr_dummy\", \"clean\", split=\"validation\")\n",
|
39 |
+
"sample = ds[0][\"audio\"]"
|
40 |
+
]
|
41 |
+
},
|
42 |
+
{
|
43 |
+
"cell_type": "code",
|
44 |
+
"execution_count": 3,
|
45 |
+
"id": "2b299bab-1228-48d9-a8a5-3d5b6c52162d",
|
46 |
+
"metadata": {},
|
47 |
+
"outputs": [
|
48 |
+
{
|
49 |
+
"data": {
|
50 |
+
"text/plain": [
|
51 |
+
"{'path': '/home/ravi.naik/.cache/huggingface/datasets/downloads/extracted/431c2c946d216530b2666a0e7ffa5ac3f5b3da89dd28858a9de6c78fae7caa4a/dev_clean/1272/128104/1272-128104-0000.flac',\n",
|
52 |
+
" 'array': array([0.00238037, 0.0020752 , 0.00198364, ..., 0.00042725, 0.00057983,\n",
|
53 |
+
" 0.0010376 ]),\n",
|
54 |
+
" 'sampling_rate': 16000}"
|
55 |
+
]
|
56 |
+
},
|
57 |
+
"execution_count": 3,
|
58 |
+
"metadata": {},
|
59 |
+
"output_type": "execute_result"
|
60 |
+
}
|
61 |
+
],
|
62 |
+
"source": [
|
63 |
+
"sample"
|
64 |
+
]
|
65 |
+
},
|
66 |
+
{
|
67 |
+
"cell_type": "code",
|
68 |
+
"execution_count": 4,
|
69 |
+
"id": "b7e570a1-cf5c-450c-a7b6-49b45a10d2df",
|
70 |
+
"metadata": {},
|
71 |
+
"outputs": [],
|
72 |
+
"source": [
|
73 |
+
"input_features = processor(sample[\"array\"], sampling_rate=sample[\"sampling_rate\"], return_tensors=\"pt\").input_features "
|
74 |
+
]
|
75 |
+
},
|
76 |
+
{
|
77 |
+
"cell_type": "code",
|
78 |
+
"execution_count": 5,
|
79 |
+
"id": "584e920b-a7fd-402d-95dd-3b9128cd34bb",
|
80 |
+
"metadata": {},
|
81 |
+
"outputs": [],
|
82 |
+
"source": [
|
83 |
+
"# generate token ids\n",
|
84 |
+
"predicted_ids = model.generate(input_features)\n",
|
85 |
+
"# decode token ids to text\n",
|
86 |
+
"transcription = processor.batch_decode(predicted_ids, skip_special_tokens=False)\n",
|
87 |
+
"\n",
|
88 |
+
"transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)"
|
89 |
+
]
|
90 |
+
},
|
91 |
+
{
|
92 |
+
"cell_type": "code",
|
93 |
+
"execution_count": 6,
|
94 |
+
"id": "b27ab660-861b-49d1-81f9-f51cb7f9d8d8",
|
95 |
+
"metadata": {},
|
96 |
+
"outputs": [
|
97 |
+
{
|
98 |
+
"data": {
|
99 |
+
"text/plain": [
|
100 |
+
"[' Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.']"
|
101 |
+
]
|
102 |
+
},
|
103 |
+
"execution_count": 6,
|
104 |
+
"metadata": {},
|
105 |
+
"output_type": "execute_result"
|
106 |
+
}
|
107 |
+
],
|
108 |
+
"source": [
|
109 |
+
"transcription"
|
110 |
+
]
|
111 |
+
},
|
112 |
+
{
|
113 |
+
"cell_type": "code",
|
114 |
+
"execution_count": 3,
|
115 |
+
"id": "eca553b8-68f6-493d-b567-3d526b49ae1b",
|
116 |
+
"metadata": {},
|
117 |
+
"outputs": [],
|
118 |
+
"source": [
|
119 |
+
"import torch\n",
|
120 |
+
"from torch import nn"
|
121 |
+
]
|
122 |
+
},
|
123 |
+
{
|
124 |
+
"cell_type": "code",
|
125 |
+
"execution_count": 4,
|
126 |
+
"id": "c619a4cf-9068-4e4d-8139-e16d15345f4f",
|
127 |
+
"metadata": {},
|
128 |
+
"outputs": [],
|
129 |
+
"source": [
|
130 |
+
"from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoTokenizer"
|
131 |
+
]
|
132 |
+
},
|
133 |
+
{
|
134 |
+
"cell_type": "code",
|
135 |
+
"execution_count": 5,
|
136 |
+
"id": "47d5b1ff-ab0f-4d11-af64-d2fa2be39286",
|
137 |
+
"metadata": {},
|
138 |
+
"outputs": [
|
139 |
+
{
|
140 |
+
"name": "stderr",
|
141 |
+
"output_type": "stream",
|
142 |
+
"text": [
|
143 |
+
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
|
144 |
+
]
|
145 |
+
}
|
146 |
+
],
|
147 |
+
"source": [
|
148 |
+
"model_name = \"microsoft/phi-2\"\n",
|
149 |
+
"phi2_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n",
|
150 |
+
"phi2_tokenizer.pad_token = phi2_tokenizer.eos_token"
|
151 |
+
]
|
152 |
+
},
|
153 |
+
{
|
154 |
+
"cell_type": "code",
|
155 |
+
"execution_count": 6,
|
156 |
+
"id": "0b36b3f0-db5b-4029-9072-0a53bcab315a",
|
157 |
+
"metadata": {},
|
158 |
+
"outputs": [
|
159 |
+
{
|
160 |
+
"ename": "NameError",
|
161 |
+
"evalue": "name 'transcription' is not defined",
|
162 |
+
"output_type": "error",
|
163 |
+
"traceback": [
|
164 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
165 |
+
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
|
166 |
+
"Cell \u001b[0;32mIn[6], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m tokens \u001b[38;5;241m=\u001b[39m phi2_tokenizer(\u001b[38;5;241m*\u001b[39m\u001b[43mtranscription\u001b[49m, return_tensors\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpt\u001b[39m\u001b[38;5;124m\"\u001b[39m, return_attention_mask\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m)\n",
|
167 |
+
"\u001b[0;31mNameError\u001b[0m: name 'transcription' is not defined"
|
168 |
+
]
|
169 |
+
}
|
170 |
+
],
|
171 |
+
"source": [
|
172 |
+
"tokens = phi2_tokenizer(*transcription, return_tensors=\"pt\", return_attention_mask=False)"
|
173 |
+
]
|
174 |
+
},
|
175 |
+
{
|
176 |
+
"cell_type": "code",
|
177 |
+
"execution_count": 22,
|
178 |
+
"id": "91f6d3d3-bb00-434f-a91e-6952375890d0",
|
179 |
+
"metadata": {},
|
180 |
+
"outputs": [
|
181 |
+
{
|
182 |
+
"data": {
|
183 |
+
"text/plain": [
|
184 |
+
"{'input_ids': tensor([[ 1770, 13, 2264, 346, 353, 318, 262, 46329, 286, 262,\n",
|
185 |
+
" 3504, 6097, 290, 356, 389, 9675, 284, 7062, 465, 21443,\n",
|
186 |
+
" 13]])}"
|
187 |
+
]
|
188 |
+
},
|
189 |
+
"execution_count": 22,
|
190 |
+
"metadata": {},
|
191 |
+
"output_type": "execute_result"
|
192 |
+
}
|
193 |
+
],
|
194 |
+
"source": [
|
195 |
+
"tokens"
|
196 |
+
]
|
197 |
+
},
|
198 |
+
{
|
199 |
+
"cell_type": "code",
|
200 |
+
"execution_count": 12,
|
201 |
+
"id": "533191d9-4b3b-417a-918d-6fe854f24b50",
|
202 |
+
"metadata": {},
|
203 |
+
"outputs": [
|
204 |
+
{
|
205 |
+
"name": "stderr",
|
206 |
+
"output_type": "stream",
|
207 |
+
"text": [
|
208 |
+
"A new version of the following files was downloaded from https://huggingface.co/microsoft/phi-2:\n",
|
209 |
+
"- configuration_phi.py\n",
|
210 |
+
". Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.\n"
|
211 |
+
]
|
212 |
+
},
|
213 |
+
{
|
214 |
+
"data": {
|
215 |
+
"application/vnd.jupyter.widget-view+json": {
|
216 |
+
"model_id": "2a65a119388b4cb4b123b532176e786e",
|
217 |
+
"version_major": 2,
|
218 |
+
"version_minor": 0
|
219 |
+
},
|
220 |
+
"text/plain": [
|
221 |
+
"modeling_phi.py: 0%| | 0.00/62.7k [00:00<?, ?B/s]"
|
222 |
+
]
|
223 |
+
},
|
224 |
+
"metadata": {},
|
225 |
+
"output_type": "display_data"
|
226 |
+
},
|
227 |
+
{
|
228 |
+
"name": "stderr",
|
229 |
+
"output_type": "stream",
|
230 |
+
"text": [
|
231 |
+
"A new version of the following files was downloaded from https://huggingface.co/microsoft/phi-2:\n",
|
232 |
+
"- modeling_phi.py\n",
|
233 |
+
". Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.\n"
|
234 |
+
]
|
235 |
+
},
|
236 |
+
{
|
237 |
+
"data": {
|
238 |
+
"application/vnd.jupyter.widget-view+json": {
|
239 |
+
"model_id": "7183811844304c16b72d53fe11098a74",
|
240 |
+
"version_major": 2,
|
241 |
+
"version_minor": 0
|
242 |
+
},
|
243 |
+
"text/plain": [
|
244 |
+
"Downloading shards: 0%| | 0/2 [00:00<?, ?it/s]"
|
245 |
+
]
|
246 |
+
},
|
247 |
+
"metadata": {},
|
248 |
+
"output_type": "display_data"
|
249 |
+
},
|
250 |
+
{
|
251 |
+
"data": {
|
252 |
+
"application/vnd.jupyter.widget-view+json": {
|
253 |
+
"model_id": "3e78fe144e8f42139a4d7a1830dbf192",
|
254 |
+
"version_major": 2,
|
255 |
+
"version_minor": 0
|
256 |
+
},
|
257 |
+
"text/plain": [
|
258 |
+
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
|
259 |
+
]
|
260 |
+
},
|
261 |
+
"metadata": {},
|
262 |
+
"output_type": "display_data"
|
263 |
+
}
|
264 |
+
],
|
265 |
+
"source": [
|
266 |
+
"bnb_config = BitsAndBytesConfig(\n",
|
267 |
+
" load_in_4bit=True,\n",
|
268 |
+
" bnb_4bit_quant_type=\"nf4\",\n",
|
269 |
+
" bnb_4bit_compute_dtype=torch.float16,\n",
|
270 |
+
")\n",
|
271 |
+
"\n",
|
272 |
+
"model = AutoModelForCausalLM.from_pretrained(\n",
|
273 |
+
" model_name,\n",
|
274 |
+
" quantization_config=bnb_config,\n",
|
275 |
+
" trust_remote_code=True,\n",
|
276 |
+
" device_map=\"cuda:0\"\n",
|
277 |
+
")\n",
|
278 |
+
"model.config.use_cache = False"
|
279 |
+
]
|
280 |
+
},
|
281 |
+
{
|
282 |
+
"cell_type": "code",
|
283 |
+
"execution_count": 19,
|
284 |
+
"id": "155c054a-a00f-4ed5-bfff-1ad64889e7f1",
|
285 |
+
"metadata": {},
|
286 |
+
"outputs": [
|
287 |
+
{
|
288 |
+
"data": {
|
289 |
+
"text/plain": [
|
290 |
+
"[' Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.\\n']"
|
291 |
+
]
|
292 |
+
},
|
293 |
+
"execution_count": 19,
|
294 |
+
"metadata": {},
|
295 |
+
"output_type": "execute_result"
|
296 |
+
}
|
297 |
+
],
|
298 |
+
"source": [
|
299 |
+
"phi2_tokenizer.batch_decode(model.generate(**tokens))"
|
300 |
+
]
|
301 |
+
},
|
302 |
+
{
|
303 |
+
"cell_type": "code",
|
304 |
+
"execution_count": 7,
|
305 |
+
"id": "04f940c9-586d-4937-ae31-cc0f96d33e92",
|
306 |
+
"metadata": {},
|
307 |
+
"outputs": [],
|
308 |
+
"source": [
|
309 |
+
"class AudioLanguageConnector:\n",
|
310 |
+
" def __init__(self):\n",
|
311 |
+
" model_name = \"microsoft/phi-2\"\n",
|
312 |
+
" self.phi2_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n",
|
313 |
+
" self.phi2_tokenizer.pad_token = self.phi2_tokenizer.eos_token\n",
|
314 |
+
"\n",
|
315 |
+
" def __call__(self, text):\n",
|
316 |
+
" text = f\"<audio_start> {text} <audio_end>\"\n",
|
317 |
+
" tokens = self.phi2_tokenizer(text, return_tensors=\"pt\", return_attention_mask=False)\n",
|
318 |
+
" return tokens\n",
|
319 |
+
" \n",
|
320 |
+
"\n",
|
321 |
+
"class WhisperWithProjection:\n",
|
322 |
+
" def __init__(self):\n",
|
323 |
+
" self.processor = WhisperProcessor.from_pretrained(\"openai/whisper-tiny\")\n",
|
324 |
+
" self.model = WhisperForConditionalGeneration.from_pretrained(\"openai/whisper-tiny\")\n",
|
325 |
+
" self.model.config.forced_decoder_ids = None\n",
|
326 |
+
" self.audio_language_connector = AudioLanguageConnector()\n",
|
327 |
+
" \n",
|
328 |
+
" def forward(self, audio):\n",
|
329 |
+
" input_features = self.processor(audio[\"array\"],\n",
|
330 |
+
" sampling_rate=audio[\"sampling_rate\"],\n",
|
331 |
+
" return_tensors=\"pt\").input_features\n",
|
332 |
+
" # generate token ids\n",
|
333 |
+
" predicted_ids = self.model.generate(input_features)\n",
|
334 |
+
" # decode token ids to text \n",
|
335 |
+
" transcription = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)\n",
|
336 |
+
"\n",
|
337 |
+
" audio_embeddings = self.audio_language_connector(transcription)\n",
|
338 |
+
" return audio_embeddings"
|
339 |
+
]
|
340 |
+
},
|
341 |
+
{
|
342 |
+
"cell_type": "code",
|
343 |
+
"execution_count": 8,
|
344 |
+
"id": "2b1f8f44-bfe6-413c-9e32-c38fa5517981",
|
345 |
+
"metadata": {},
|
346 |
+
"outputs": [],
|
347 |
+
"source": [
|
348 |
+
"class TextModality:\n",
|
349 |
+
" def __init__(self):\n",
|
350 |
+
" model_name = \"microsoft/phi-2\"\n",
|
351 |
+
" self.phi2_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n",
|
352 |
+
" self.phi2_tokenizer.pad_token = self.phi2_tokenizer.eos_token\n",
|
353 |
+
"\n",
|
354 |
+
" def __call__(self, text):\n",
|
355 |
+
" tokens = self.phi2_tokenizer(text, return_tensors=\"pt\", return_attention_mask=False)\n",
|
356 |
+
" return tokens"
|
357 |
+
]
|
358 |
+
},
|
359 |
+
{
|
360 |
+
"cell_type": "code",
|
361 |
+
"execution_count": 15,
|
362 |
+
"id": "21c51648-abb6-4bbd-b4c1-509967a69337",
|
363 |
+
"metadata": {},
|
364 |
+
"outputs": [],
|
365 |
+
"source": [
|
366 |
+
"class MultiModalPhi2:\n",
|
367 |
+
" def __init__(self):\n",
|
368 |
+
" self.text_modality = TextModality()\n",
|
369 |
+
" self.whisper_w_proj = WhisperWithProjection()\n",
|
370 |
+
" self.llm = self.load_llm()\n",
|
371 |
+
"\n",
|
372 |
+
" def load_llm(self):\n",
|
373 |
+
" bnb_config = BitsAndBytesConfig(\n",
|
374 |
+
" load_in_4bit=True,\n",
|
375 |
+
" bnb_4bit_quant_type=\"nf4\",\n",
|
376 |
+
" bnb_4bit_compute_dtype=torch.float16)\n",
|
377 |
+
" \n",
|
378 |
+
" model = AutoModelForCausalLM.from_pretrained(\n",
|
379 |
+
" model_name,\n",
|
380 |
+
" quantization_config=bnb_config,\n",
|
381 |
+
" trust_remote_code=True,\n",
|
382 |
+
" device_map=\"cuda:0\"\n",
|
383 |
+
" )\n",
|
384 |
+
" model.config.use_cache = False\n",
|
385 |
+
" return model\n",
|
386 |
+
"\n",
|
387 |
+
" def generate(self, audio, text):\n",
|
388 |
+
" text_embeddings = self.text_modality(text)\n",
|
389 |
+
" audio_embeddings = self.whisper_w_proj.forward(audio)\n",
|
390 |
+
" inputs = torch.concat([text_embeddings[\"input_ids\"], audio_embeddings[\"input_ids\"]], dim=1)\n",
|
391 |
+
" \n",
|
392 |
+
" # outputs = self.llm.generate(inputs, max_length=200)\n",
|
393 |
+
" outputs = self.llm(inputs)\n",
|
394 |
+
" return outputs\n",
|
395 |
+
" \n",
|
396 |
+
" # text = self.text_modality.phi2_tokenizer.batch_decode(outputs)[0]\n",
|
397 |
+
" # print(text)"
|
398 |
+
]
|
399 |
+
},
|
400 |
+
{
|
401 |
+
"cell_type": "code",
|
402 |
+
"execution_count": 16,
|
403 |
+
"id": "472a00cb-bae9-4c09-a0ef-bc57881b5e2c",
|
404 |
+
"metadata": {},
|
405 |
+
"outputs": [
|
406 |
+
{
|
407 |
+
"name": "stderr",
|
408 |
+
"output_type": "stream",
|
409 |
+
"text": [
|
410 |
+
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n",
|
411 |
+
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
|
412 |
+
]
|
413 |
+
},
|
414 |
+
{
|
415 |
+
"data": {
|
416 |
+
"application/vnd.jupyter.widget-view+json": {
|
417 |
+
"model_id": "2236e6b1e26d444fa3d48181ba1a6cf9",
|
418 |
+
"version_major": 2,
|
419 |
+
"version_minor": 0
|
420 |
+
},
|
421 |
+
"text/plain": [
|
422 |
+
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
|
423 |
+
]
|
424 |
+
},
|
425 |
+
"metadata": {},
|
426 |
+
"output_type": "display_data"
|
427 |
+
}
|
428 |
+
],
|
429 |
+
"source": [
|
430 |
+
"multi_modal_phi = MultiModalPhi2()"
|
431 |
+
]
|
432 |
+
},
|
433 |
+
{
|
434 |
+
"cell_type": "code",
|
435 |
+
"execution_count": 17,
|
436 |
+
"id": "c350f2d3-0929-4c46-b63d-ff92dea437f3",
|
437 |
+
"metadata": {},
|
438 |
+
"outputs": [
|
439 |
+
{
|
440 |
+
"data": {
|
441 |
+
"text/plain": [
|
442 |
+
"CausalLMOutputWithPast(loss={'logits': tensor([[[ 6.9531, 9.9375, 7.0234, ..., 2.0020, 2.0020, 2.0000],\n",
|
443 |
+
" [ 8.9062, 12.1172, 7.5977, ..., -1.2012, -1.2012, -1.2012],\n",
|
444 |
+
" [ 7.0273, 5.3477, 3.6328, ..., -4.2070, -4.2070, -4.2070],\n",
|
445 |
+
" ...,\n",
|
446 |
+
" [ 7.0234, 7.4414, 9.1016, ..., 1.0117, 1.0127, 1.0117],\n",
|
447 |
+
" [ 9.4531, 10.0391, 9.7578, ..., 0.0776, 0.0775, 0.0764],\n",
|
448 |
+
" [ 8.0703, 6.6445, 5.5156, ..., -1.9268, -1.9268, -1.9277]]],\n",
|
449 |
+
" grad_fn=<ToCopyBackward0>)}, logits=tensor([[[ 6.9531, 9.9375, 7.0234, ..., 2.0020, 2.0020, 2.0000],\n",
|
450 |
+
" [ 8.9062, 12.1172, 7.5977, ..., -1.2012, -1.2012, -1.2012],\n",
|
451 |
+
" [ 7.0273, 5.3477, 3.6328, ..., -4.2070, -4.2070, -4.2070],\n",
|
452 |
+
" ...,\n",
|
453 |
+
" [ 7.0234, 7.4414, 9.1016, ..., 1.0117, 1.0127, 1.0117],\n",
|
454 |
+
" [ 9.4531, 10.0391, 9.7578, ..., 0.0776, 0.0775, 0.0764],\n",
|
455 |
+
" [ 8.0703, 6.6445, 5.5156, ..., -1.9268, -1.9268, -1.9277]]],\n",
|
456 |
+
" grad_fn=<ToCopyBackward0>), past_key_values=None, hidden_states=None, attentions=None)"
|
457 |
+
]
|
458 |
+
},
|
459 |
+
"execution_count": 17,
|
460 |
+
"metadata": {},
|
461 |
+
"output_type": "execute_result"
|
462 |
+
}
|
463 |
+
],
|
464 |
+
"source": [
|
465 |
+
"audio = sample\n",
|
466 |
+
"text = \"explain about the audio\"\n",
|
467 |
+
"multi_modal_phi.generate(audio, text)"
|
468 |
+
]
|
469 |
+
},
|
470 |
+
{
|
471 |
+
"cell_type": "code",
|
472 |
+
"execution_count": null,
|
473 |
+
"id": "46aa9c66-a5bb-4760-8895-92673f49345f",
|
474 |
+
"metadata": {},
|
475 |
+
"outputs": [],
|
476 |
+
"source": []
|
477 |
+
}
|
478 |
+
],
|
479 |
+
"metadata": {
|
480 |
+
"kernelspec": {
|
481 |
+
"display_name": "Python 3 (ipykernel)",
|
482 |
+
"language": "python",
|
483 |
+
"name": "python3"
|
484 |
+
},
|
485 |
+
"language_info": {
|
486 |
+
"codemirror_mode": {
|
487 |
+
"name": "ipython",
|
488 |
+
"version": 3
|
489 |
+
},
|
490 |
+
"file_extension": ".py",
|
491 |
+
"mimetype": "text/x-python",
|
492 |
+
"name": "python",
|
493 |
+
"nbconvert_exporter": "python",
|
494 |
+
"pygments_lexer": "ipython3",
|
495 |
+
"version": "3.10.12"
|
496 |
+
}
|
497 |
+
},
|
498 |
+
"nbformat": 4,
|
499 |
+
"nbformat_minor": 5
|
500 |
+
}
|
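One note on the audio path in this notebook: despite the `projection` naming, `WhisperWithProjection` returns phi-2 token ids of the Whisper transcription rather than continuous embeddings, which is why concatenating them with the text token ids works directly. A short sketch of getting decoded text out of that combined prompt, reusing `multi_modal_phi` and `sample` from the cells above (`max_new_tokens` is an arbitrary choice; the notebook's `generate` method returns raw logits instead):

```python
import torch

text_tokens = multi_modal_phi.text_modality("explain about the audio")
audio_tokens = multi_modal_phi.whisper_w_proj.forward(sample)  # token ids of the transcription

prompt_ids = torch.cat(
    [text_tokens["input_ids"], audio_tokens["input_ids"]], dim=1
).to(multi_modal_phi.llm.device)

generated = multi_modal_phi.llm.generate(prompt_ids, max_new_tokens=100)
print(
    multi_modal_phi.text_modality.phi2_tokenizer.batch_decode(
        generated, skip_special_tokens=True
    )[0]
)
```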
README.md
CHANGED
@@ -1,13 +1,44 @@
 ---
 title: MultiModal Phi2
-emoji:
-colorFrom:
-colorTo:
+emoji: 🚀
+colorFrom: blue
+colorTo: red
 sdk: gradio
-sdk_version:
+sdk_version: 3.35.2
 app_file: app.py
 pinned: false
 license: mit
 ---
+## Phi2 : Multimodal Finetuning
+### Details
+1. LLM Backbone: Phi2
+2. Vision Tower: clip-vit-large-patch14-336
+3. Audio Model: Whisper
+4. Pretraining Dataset: LAION-CC-SBU dataset with BLIP captions (200k samples)
+5. Finetuning Dataset: Instruct 150k dataset based on COCO
 
-
+### Design
+![image](https://github.com/RaviNaik/ERA-CAPSTONE/assets/23289802/56df24cd-2681-4e17-ab64-9652f609b15f)
+
+### Pretraining
+#### Training Loss Curve
+![image](https://github.com/RaviNaik/ERA-CAPSTONE/assets/23289802/b6c37a95-0a56-4b52-8719-3ff56dc1b703)
+
+#### Learning Rate
+![image](https://github.com/RaviNaik/ERA-CAPSTONE/assets/23289802/44d9a11b-b28d-47e1-ba1d-d6dc22ebe748)
+
+#### Training Logs
+![image](https://github.com/RaviNaik/ERA-CAPSTONE/assets/23289802/76543d98-d9fe-4c1a-ac47-3d06e48053ad)
+
+### Finetuning
+#### Training Loss Curve
+![image](https://github.com/RaviNaik/ERA-CAPSTONE/assets/23289802/45ef40bd-fae5-4cfe-a522-c0eed2833230)
+
+#### Learning Rate
+![image](https://github.com/RaviNaik/ERA-CAPSTONE/assets/23289802/df60ee62-a537-4e36-a7f7-f7111e101162)
+
+#### Training Logs
+![image](https://github.com/RaviNaik/ERA-CAPSTONE/assets/23289802/2747acce-bc99-4c37-a05a-d5e81cb9aa9d)
+
+### Results
+![image](https://github.com/RaviNaik/ERA-CAPSTONE/assets/23289802/f12a9f04-df32-413e-b957-774c30381b2b)
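The Details and Design sections above describe a LLaVA-style layout: a CLIP ViT-L/14-336 vision tower and a Whisper audio model feeding the Phi2 backbone. As a rough, non-authoritative illustration of the projection step pictured in the design diagram, the sketch below maps vision-tower features into the LLM's embedding space; the module name and layer sizes (1024-d CLIP features, 2560-d Phi-2 hidden size) are assumptions, not code taken from this repository.

```python
# Schematic sketch of a LLaVA-style projector (assumed shapes, not the project's code).
import torch
import torch.nn as nn

class MultiModalProjector(nn.Module):
    """Project CLIP patch embeddings into the Phi-2 embedding space (hypothetical sizes)."""

    def __init__(self, vision_dim: int = 1024, llm_dim: int = 2560):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Linear(vision_dim, llm_dim),
            nn.GELU(),
            nn.Linear(llm_dim, llm_dim),
        )

    def forward(self, image_feats: torch.Tensor) -> torch.Tensor:
        # image_feats: (batch, num_patches, vision_dim) -> (batch, num_patches, llm_dim)
        return self.proj(image_feats)

# Audio presumably takes a different path: Whisper transcribes speech and the
# transcript is tokenized like ordinary prompt text (an assumption based on the
# component list above, since Whisper is an ASR model).
```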
app.py
ADDED
@@ -0,0 +1,128 @@
+import gradio as gr
+from PIL import Image
+from inference.main import MultiModalPhi2
+
+messages = []
+
+multimodal_phi2 = MultiModalPhi2(
+    modelname_or_path="RaviNaik/Llava-Phi2",
+    temperature=0.2,
+    max_new_tokens=1024,
+    device="cpu",
+)
+
+
+def add_content(chatbot, text, image, audio_upload, audio_mic) -> gr.Chatbot:
+    textflag, imageflag, audioflag = False, False, False
+    if text not in ["", None]:
+        chatbot.append((text, None))
+        textflag = True
+    if image is not None:
+        chatbot.append(((image,), None))
+        imageflag = True
+    if audio_mic is not None:
+        chatbot.append(((audio_mic,), None))
+        audioflag = True
+    else:
+        if audio_upload is not None:
+            chatbot.append(((audio_upload,), None))
+            audioflag = True
+    if not any([textflag, imageflag, audioflag]):
+        # Raise an error if neither text nor file is provided
+        raise gr.Error("Enter a valid text, image or audio")
+    return chatbot
+
+
+def clear_data():
+    return {prompt: None, image: None, audio_upload: None, audio_mic: None, chatbot: []}
+
+
+def run(history, text, image, audio_upload, audio_mic):
+    if text in [None, ""]:
+        text = None
+
+    if audio_upload is not None:
+        audio = audio_upload
+    elif audio_mic is not None:
+        audio = audio_mic
+    else:
+        audio = None
+
+    print("text", text)
+    print("image", image)
+    print("audio", audio)
+
+    if image is not None:
+        image = Image.open(image)
+    outputs = multimodal_phi2(text, audio, image)
+    # outputs = ""
+
+    history.append((None, outputs.title()))
+    return history, None, None, None, None
+
+
+with gr.Blocks() as demo:
+    gr.Markdown("## MultiModal Phi2 Model Pretraining and Finetuning from Scratch")
+    gr.Markdown(
+        """This is a multimodal implementation of [Phi2](https://huggingface.co/microsoft/phi-2) model.
+
+        Please find the source code and training details [here](https://github.com/RaviNaik/ERA-CAPSTONE/MultiModalPhi2).
+
+        ### Details:
+        1. LLM Backbone: [Phi2](https://huggingface.co/microsoft/phi-2)
+        2. Vision Tower: [clip-vit-large-patch14-336](https://huggingface.co/openai/clip-vit-large-patch14-336)
+        3. Audio Model: [Whisper Tiny](https://huggingface.co/openai/whisper-tiny)
+        4. Pretraining Dataset: [LAION-CC-SBU dataset with BLIP captions (200k samples)](https://huggingface.co/datasets/liuhaotian/LLaVA-Pretrain)
+        5. Finetuning Dataset: [Instruct 150k dataset based on COCO](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K)
+        6. Finetuned Model: [RaviNaik/Llava-Phi2](https://huggingface.co/RaviNaik/Llava-Phi2)
+        """
+    )
+    with gr.Row():
+        with gr.Column(scale=4):
+            # Creating a column with a scale of 4
+            with gr.Box():
+                with gr.Row():
+                    # Adding a Textbox with a placeholder "write prompt"
+                    prompt = gr.Textbox(
+                        placeholder="Enter Prompt", lines=2, label="Query", value=None
+                    )
+                with gr.Row():
+                    # Adding image
+                    image = gr.Image(type="filepath", value=None)
+                with gr.Row():
+                    # Adding audio (upload and microphone)
+                    audio_upload = gr.Audio(source="upload", type="filepath")
+                    audio_mic = gr.Audio(
+                        source="microphone", type="filepath", format="mp3"
+                    )
+
+        with gr.Column(scale=8):
+            with gr.Box():
+                with gr.Row():
+                    chatbot = gr.Chatbot(
+                        avatar_images=("🧑", "🤖"),
+                        height=550,
+                    )
+                with gr.Row():
+                    # Adding the Submit and Clear buttons
+                    submit = gr.Button()
+                    clear = gr.Button(value="Clear")
+
+    submit.click(
+        add_content,
+        inputs=[chatbot, prompt, image, audio_upload, audio_mic],
+        outputs=[chatbot],
+    ).success(
+        run,
+        inputs=[chatbot, prompt, image, audio_upload, audio_mic],
+        outputs=[chatbot, prompt, image, audio_upload, audio_mic],
+    )
+
+    clear.click(
+        clear_data,
+        outputs=[prompt, image, audio_upload, audio_mic, chatbot],
+    )
+
+demo.launch()
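The Gradio app above funnels every request through `multimodal_phi2(text, audio, image)`. Below is a minimal sketch of driving the same wrapper without the UI, assuming the `inference.main` interface that `app.py` imports; the image path and prompt are placeholders.

```python
# Sketch: call the inference wrapper from app.py directly, with no Gradio UI.
from PIL import Image
from inference.main import MultiModalPhi2

model = MultiModalPhi2(
    modelname_or_path="RaviNaik/Llava-Phi2",
    temperature=0.2,
    max_new_tokens=1024,
    device="cpu",
)

image = Image.open("example.jpg")  # placeholder path
# Same (text, audio, image) call order that run() uses above; audio may be None.
answer = model("What is shown in this image?", None, image)
print(answer)
```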
requirements.txt
ADDED
@@ -0,0 +1,21 @@
+einops==0.6.1
+einops-exts==0.0.4
+timm==0.6.13
+gradio==3.35.2
+gradio_client==0.2.9
+markdown2[all]
+numpy
+requests
+tokenizers==0.15.0
+torch==2.0.1
+shortuuid
+httpx==0.24.0
+deepspeed==0.9.5
+peft==0.4.0
+transformers==4.36.2
+accelerate==0.21.0
+bitsandbytes==0.41.0
+scikit-learn==1.2.2
+sentencepiece==0.1.99
+librosa
+soundfile