hpc-yekin
commited on
Commit
•
92e0882
1
Parent(s):
4aa0b3a
initial commit
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- AlphaCLIP/.gitignore +12 -0
- AlphaCLIP/LICENSE +201 -0
- AlphaCLIP/MANIFEST.in +1 -0
- AlphaCLIP/alpha_clip/__init__.py +1 -0
- AlphaCLIP/alpha_clip/alpha_clip.py +250 -0
- AlphaCLIP/alpha_clip/bpe_simple_vocab_16e6.txt.gz +3 -0
- AlphaCLIP/alpha_clip/model.py +598 -0
- AlphaCLIP/alpha_clip/simple_tokenizer.py +132 -0
- AlphaCLIP/eval/README.md +6 -0
- AlphaCLIP/eval/imagenet_s_zs_test/.gitignore +2 -0
- AlphaCLIP/eval/imagenet_s_zs_test/README.md +21 -0
- AlphaCLIP/eval/imagenet_s_zs_test/imagenet_s.py +149 -0
- AlphaCLIP/eval/imagenet_s_zs_test/imagenet_s_zs_test.py +66 -0
- AlphaCLIP/eval/rec_zs_test/LICENSE.md +201 -0
- AlphaCLIP/eval/rec_zs_test/README.md +74 -0
- AlphaCLIP/eval/rec_zs_test/cache/.gitkeep +0 -0
- AlphaCLIP/eval/rec_zs_test/cal_acc.py +21 -0
- AlphaCLIP/eval/rec_zs_test/ckpt/.gitkeep +0 -0
- AlphaCLIP/eval/rec_zs_test/data/.gitkeep +0 -0
- AlphaCLIP/eval/rec_zs_test/entity_extraction.py +142 -0
- AlphaCLIP/eval/rec_zs_test/executor.py +401 -0
- AlphaCLIP/eval/rec_zs_test/generic_clip_pairs.py +107 -0
- AlphaCLIP/eval/rec_zs_test/heuristics.py +68 -0
- AlphaCLIP/eval/rec_zs_test/interpreter.py +212 -0
- AlphaCLIP/eval/rec_zs_test/lattice.py +70 -0
- AlphaCLIP/eval/rec_zs_test/main.py +200 -0
- AlphaCLIP/eval/rec_zs_test/methods/__init__.py +3 -0
- AlphaCLIP/eval/rec_zs_test/methods/baseline.py +57 -0
- AlphaCLIP/eval/rec_zs_test/methods/parse.py +239 -0
- AlphaCLIP/eval/rec_zs_test/methods/random_method.py +30 -0
- AlphaCLIP/eval/rec_zs_test/methods/ref_method.py +13 -0
- AlphaCLIP/eval/rec_zs_test/output/.gitkeep +0 -0
- AlphaCLIP/eval/rec_zs_test/requirements.txt +53 -0
- AlphaCLIP/eval/rec_zs_test/run.sh +1 -0
- AlphaCLIP/eval/rec_zs_test/run_multi_gpus.sh +15 -0
- AlphaCLIP/hubconf.py +42 -0
- AlphaCLIP/requirements.txt +5 -0
- AlphaCLIP/setup.py +21 -0
- README.md +1 -1
- app.py +113 -0
- clip_l14_grit+mim_fultune_6xe.pth +3 -0
- config/inference_config.yaml +16 -0
- image_encoder/config.json +23 -0
- image_encoder/pytorch_model.bin +3 -0
- ip-adapter_sd15.bin +3 -0
- model.safetensors +3 -0
- model/__init__.py +5 -0
- model/attention_processor.py +189 -0
- model/clip_away.py +280 -0
- model/resampler.py +158 -0
AlphaCLIP/.gitignore
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__pycache__/
|
2 |
+
*.py[cod]
|
3 |
+
*$py.class
|
4 |
+
*.egg-info
|
5 |
+
.pytest_cache
|
6 |
+
.ipynb_checkpoints
|
7 |
+
|
8 |
+
thumbs.db
|
9 |
+
.DS_Store
|
10 |
+
.idea
|
11 |
+
checkpoints/*
|
12 |
+
*.pth
|
AlphaCLIP/LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [Zeyi Sun] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
AlphaCLIP/MANIFEST.in
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
include alpha_clip/bpe_simple_vocab_16e6.txt.gz
|
AlphaCLIP/alpha_clip/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .alpha_clip import *
|
AlphaCLIP/alpha_clip/alpha_clip.py
ADDED
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import hashlib
|
2 |
+
import os
|
3 |
+
import urllib
|
4 |
+
import warnings
|
5 |
+
from typing import Any, Union, List
|
6 |
+
from pkg_resources import packaging
|
7 |
+
|
8 |
+
import torch
|
9 |
+
from PIL import Image
|
10 |
+
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
|
11 |
+
from tqdm import tqdm
|
12 |
+
|
13 |
+
from .model import build_model
|
14 |
+
from .simple_tokenizer import SimpleTokenizer as _Tokenizer
|
15 |
+
|
16 |
+
try:
|
17 |
+
from torchvision.transforms import InterpolationMode
|
18 |
+
BICUBIC = InterpolationMode.BICUBIC
|
19 |
+
except ImportError:
|
20 |
+
BICUBIC = Image.BICUBIC
|
21 |
+
|
22 |
+
|
23 |
+
if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
|
24 |
+
warnings.warn("PyTorch version 1.7.1 or higher is recommended")
|
25 |
+
|
26 |
+
|
27 |
+
__all__ = ["available_models", "load", "tokenize"]
|
28 |
+
_tokenizer = _Tokenizer()
|
29 |
+
|
30 |
+
_MODELS = {
|
31 |
+
"RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
|
32 |
+
"RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
|
33 |
+
"RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
|
34 |
+
"RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
|
35 |
+
"RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
|
36 |
+
"ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
|
37 |
+
"ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
|
38 |
+
"ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
|
39 |
+
"ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
|
40 |
+
}
|
41 |
+
|
42 |
+
|
43 |
+
def _download(url: str, root: str):
|
44 |
+
os.makedirs(root, exist_ok=True)
|
45 |
+
filename = os.path.basename(url)
|
46 |
+
|
47 |
+
expected_sha256 = url.split("/")[-2]
|
48 |
+
download_target = os.path.join(root, filename)
|
49 |
+
|
50 |
+
if os.path.exists(download_target) and not os.path.isfile(download_target):
|
51 |
+
raise RuntimeError(f"{download_target} exists and is not a regular file")
|
52 |
+
|
53 |
+
if os.path.isfile(download_target):
|
54 |
+
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
|
55 |
+
return download_target
|
56 |
+
else:
|
57 |
+
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
|
58 |
+
|
59 |
+
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
|
60 |
+
with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
|
61 |
+
while True:
|
62 |
+
buffer = source.read(8192)
|
63 |
+
if not buffer:
|
64 |
+
break
|
65 |
+
|
66 |
+
output.write(buffer)
|
67 |
+
loop.update(len(buffer))
|
68 |
+
|
69 |
+
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
|
70 |
+
raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match")
|
71 |
+
|
72 |
+
return download_target
|
73 |
+
|
74 |
+
|
75 |
+
def _convert_image_to_rgb(image):
|
76 |
+
return image.convert("RGB")
|
77 |
+
|
78 |
+
|
79 |
+
def _transform(n_px):
|
80 |
+
return Compose([
|
81 |
+
Resize(n_px, interpolation=BICUBIC),
|
82 |
+
CenterCrop(n_px),
|
83 |
+
_convert_image_to_rgb,
|
84 |
+
ToTensor(),
|
85 |
+
Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
|
86 |
+
])
|
87 |
+
|
88 |
+
|
89 |
+
def available_models() -> List[str]:
|
90 |
+
"""Returns the names of available CLIP models"""
|
91 |
+
return list(_MODELS.keys())
|
92 |
+
|
93 |
+
|
94 |
+
def load(name: str, alpha_vision_ckpt_pth="None", device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None, lora_adapt=False, rank=16):
|
95 |
+
"""Load a CLIP model
|
96 |
+
|
97 |
+
Parameters
|
98 |
+
----------
|
99 |
+
name : str
|
100 |
+
A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
|
101 |
+
|
102 |
+
alpha_vision_ckpt_pth: str
|
103 |
+
only changed when inferencing model instead of training
|
104 |
+
|
105 |
+
device : Union[str, torch.device]
|
106 |
+
The device to put the loaded model
|
107 |
+
|
108 |
+
jit : bool
|
109 |
+
Whether to load the optimized JIT model or more hackable non-JIT model (default).
|
110 |
+
|
111 |
+
download_root: str
|
112 |
+
path to download the model files; by default, it uses "~/.cache/clip"
|
113 |
+
|
114 |
+
Returns
|
115 |
+
-------
|
116 |
+
model : torch.nn.Module
|
117 |
+
The CLIP model
|
118 |
+
|
119 |
+
preprocess : Callable[[PIL.Image], torch.Tensor]
|
120 |
+
A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
|
121 |
+
"""
|
122 |
+
if name in _MODELS:
|
123 |
+
model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
|
124 |
+
elif os.path.isfile(name):
|
125 |
+
model_path = name
|
126 |
+
else:
|
127 |
+
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
|
128 |
+
|
129 |
+
with open(model_path, 'rb') as opened_file:
|
130 |
+
try:
|
131 |
+
# loading JIT archive
|
132 |
+
model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
|
133 |
+
state_dict = None
|
134 |
+
except RuntimeError:
|
135 |
+
# loading saved state dict
|
136 |
+
if jit:
|
137 |
+
warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
|
138 |
+
jit = False
|
139 |
+
state_dict = torch.load(opened_file, map_location="cpu")
|
140 |
+
|
141 |
+
if not jit:
|
142 |
+
model = build_model(state_dict or model.state_dict(), lora_adapt=lora_adapt, rank=rank).to(device)
|
143 |
+
if str(device) == "cpu":
|
144 |
+
model.float()
|
145 |
+
if alpha_vision_ckpt_pth != "None":
|
146 |
+
model.visual.load_state_dict(torch.load(alpha_vision_ckpt_pth))
|
147 |
+
model.eval() # merge lora params if exists (for inference only)
|
148 |
+
return model, _transform(model.visual.input_resolution)
|
149 |
+
|
150 |
+
# patch the device names
|
151 |
+
device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
|
152 |
+
device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
|
153 |
+
|
154 |
+
def _node_get(node: torch._C.Node, key: str):
|
155 |
+
"""Gets attributes of a node which is polymorphic over return type.
|
156 |
+
|
157 |
+
From https://github.com/pytorch/pytorch/pull/82628
|
158 |
+
"""
|
159 |
+
sel = node.kindOf(key)
|
160 |
+
return getattr(node, sel)(key)
|
161 |
+
|
162 |
+
def patch_device(module):
|
163 |
+
try:
|
164 |
+
graphs = [module.graph] if hasattr(module, "graph") else []
|
165 |
+
except RuntimeError:
|
166 |
+
graphs = []
|
167 |
+
|
168 |
+
if hasattr(module, "forward1"):
|
169 |
+
graphs.append(module.forward1.graph)
|
170 |
+
|
171 |
+
for graph in graphs:
|
172 |
+
for node in graph.findAllNodes("prim::Constant"):
|
173 |
+
if "value" in node.attributeNames() and str(_node_get(node, "value")).startswith("cuda"):
|
174 |
+
node.copyAttributes(device_node)
|
175 |
+
|
176 |
+
model.apply(patch_device)
|
177 |
+
patch_device(model.encode_image)
|
178 |
+
patch_device(model.encode_text)
|
179 |
+
|
180 |
+
# patch dtype to float32 on CPU
|
181 |
+
if str(device) == "cpu":
|
182 |
+
float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
|
183 |
+
float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
|
184 |
+
float_node = float_input.node()
|
185 |
+
|
186 |
+
def patch_float(module):
|
187 |
+
try:
|
188 |
+
graphs = [module.graph] if hasattr(module, "graph") else []
|
189 |
+
except RuntimeError:
|
190 |
+
graphs = []
|
191 |
+
|
192 |
+
if hasattr(module, "forward1"):
|
193 |
+
graphs.append(module.forward1.graph)
|
194 |
+
|
195 |
+
for graph in graphs:
|
196 |
+
for node in graph.findAllNodes("aten::to"):
|
197 |
+
inputs = list(node.inputs())
|
198 |
+
for i in [1, 2]: # dtype can be the second or third argument to aten::to()
|
199 |
+
if _node_get(inputs[i].node(), "value") == 5:
|
200 |
+
inputs[i].node().copyAttributes(float_node)
|
201 |
+
|
202 |
+
model.apply(patch_float)
|
203 |
+
patch_float(model.encode_image)
|
204 |
+
patch_float(model.encode_text)
|
205 |
+
|
206 |
+
model.float()
|
207 |
+
return model, _transform(model.input_resolution.item())
|
208 |
+
|
209 |
+
|
210 |
+
def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = True) -> Union[torch.IntTensor, torch.LongTensor]:
|
211 |
+
"""
|
212 |
+
Returns the tokenized representation of given input string(s)
|
213 |
+
|
214 |
+
Parameters
|
215 |
+
----------
|
216 |
+
texts : Union[str, List[str]]
|
217 |
+
An input string or a list of input strings to tokenize
|
218 |
+
|
219 |
+
context_length : int
|
220 |
+
The context length to use; all CLIP models use 77 as the context length
|
221 |
+
|
222 |
+
truncate: bool
|
223 |
+
Whether to truncate the text in case its encoding is longer than the context length
|
224 |
+
|
225 |
+
Returns
|
226 |
+
-------
|
227 |
+
A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
|
228 |
+
We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
|
229 |
+
"""
|
230 |
+
if isinstance(texts, str):
|
231 |
+
texts = [texts]
|
232 |
+
|
233 |
+
sot_token = _tokenizer.encoder["<|startoftext|>"]
|
234 |
+
eot_token = _tokenizer.encoder["<|endoftext|>"]
|
235 |
+
all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
|
236 |
+
if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
|
237 |
+
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
|
238 |
+
else:
|
239 |
+
result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
|
240 |
+
|
241 |
+
for i, tokens in enumerate(all_tokens):
|
242 |
+
if len(tokens) > context_length:
|
243 |
+
if truncate:
|
244 |
+
tokens = tokens[:context_length]
|
245 |
+
tokens[-1] = eot_token
|
246 |
+
else:
|
247 |
+
raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
|
248 |
+
result[i, :len(tokens)] = torch.tensor(tokens)
|
249 |
+
|
250 |
+
return result
|
AlphaCLIP/alpha_clip/bpe_simple_vocab_16e6.txt.gz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
|
3 |
+
size 1356917
|
AlphaCLIP/alpha_clip/model.py
ADDED
@@ -0,0 +1,598 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import OrderedDict
|
2 |
+
from typing import Tuple, Union
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from torch import nn
|
8 |
+
import loralib as lora
|
9 |
+
import math
|
10 |
+
import collections
|
11 |
+
|
12 |
+
class Bottleneck(nn.Module):
|
13 |
+
expansion = 4
|
14 |
+
|
15 |
+
def __init__(self, inplanes, planes, stride=1):
|
16 |
+
super().__init__()
|
17 |
+
|
18 |
+
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
|
19 |
+
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
|
20 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
21 |
+
self.relu1 = nn.ReLU(inplace=True)
|
22 |
+
|
23 |
+
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
|
24 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
25 |
+
self.relu2 = nn.ReLU(inplace=True)
|
26 |
+
|
27 |
+
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
|
28 |
+
|
29 |
+
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
|
30 |
+
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
31 |
+
self.relu3 = nn.ReLU(inplace=True)
|
32 |
+
|
33 |
+
self.downsample = None
|
34 |
+
self.stride = stride
|
35 |
+
|
36 |
+
if stride > 1 or inplanes != planes * Bottleneck.expansion:
|
37 |
+
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
|
38 |
+
self.downsample = nn.Sequential(OrderedDict([
|
39 |
+
("-1", nn.AvgPool2d(stride)),
|
40 |
+
("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
|
41 |
+
("1", nn.BatchNorm2d(planes * self.expansion))
|
42 |
+
]))
|
43 |
+
|
44 |
+
def forward(self, x: torch.Tensor):
|
45 |
+
identity = x
|
46 |
+
|
47 |
+
out = self.relu1(self.bn1(self.conv1(x)))
|
48 |
+
out = self.relu2(self.bn2(self.conv2(out)))
|
49 |
+
out = self.avgpool(out)
|
50 |
+
out = self.bn3(self.conv3(out))
|
51 |
+
|
52 |
+
if self.downsample is not None:
|
53 |
+
identity = self.downsample(x)
|
54 |
+
|
55 |
+
out += identity
|
56 |
+
out = self.relu3(out)
|
57 |
+
return out
|
58 |
+
|
59 |
+
|
60 |
+
class AttentionPool2d(nn.Module):
|
61 |
+
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
|
62 |
+
super().__init__()
|
63 |
+
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
|
64 |
+
self.k_proj = nn.Linear(embed_dim, embed_dim)
|
65 |
+
self.q_proj = nn.Linear(embed_dim, embed_dim)
|
66 |
+
self.v_proj = nn.Linear(embed_dim, embed_dim)
|
67 |
+
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
|
68 |
+
self.num_heads = num_heads
|
69 |
+
|
70 |
+
def forward(self, x):
|
71 |
+
x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC
|
72 |
+
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
|
73 |
+
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
|
74 |
+
x, _ = F.multi_head_attention_forward(
|
75 |
+
query=x[:1], key=x, value=x,
|
76 |
+
embed_dim_to_check=x.shape[-1],
|
77 |
+
num_heads=self.num_heads,
|
78 |
+
q_proj_weight=self.q_proj.weight,
|
79 |
+
k_proj_weight=self.k_proj.weight,
|
80 |
+
v_proj_weight=self.v_proj.weight,
|
81 |
+
in_proj_weight=None,
|
82 |
+
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
|
83 |
+
bias_k=None,
|
84 |
+
bias_v=None,
|
85 |
+
add_zero_attn=False,
|
86 |
+
dropout_p=0,
|
87 |
+
out_proj_weight=self.c_proj.weight,
|
88 |
+
out_proj_bias=self.c_proj.bias,
|
89 |
+
use_separate_proj_weight=True,
|
90 |
+
training=self.training,
|
91 |
+
need_weights=False
|
92 |
+
)
|
93 |
+
return x.squeeze(0)
|
94 |
+
|
95 |
+
|
96 |
+
class ModifiedResNet(nn.Module):
|
97 |
+
"""
|
98 |
+
A ResNet class that is similar to torchvision's but contains the following changes:
|
99 |
+
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
|
100 |
+
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
|
101 |
+
- The final pooling layer is a QKV attention instead of an average pool
|
102 |
+
"""
|
103 |
+
|
104 |
+
def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
|
105 |
+
super().__init__()
|
106 |
+
self.output_dim = output_dim
|
107 |
+
self.input_resolution = input_resolution
|
108 |
+
|
109 |
+
# the 3-layer stem
|
110 |
+
self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
|
111 |
+
self.conv1_alpha = nn.Conv2d(in_channels=1, out_channels=width // 2, kernel_size=3, stride=2, padding=1, bias=False)
|
112 |
+
self.bn1 = nn.BatchNorm2d(width // 2)
|
113 |
+
self.relu1 = nn.ReLU(inplace=True)
|
114 |
+
self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
|
115 |
+
self.bn2 = nn.BatchNorm2d(width // 2)
|
116 |
+
self.relu2 = nn.ReLU(inplace=True)
|
117 |
+
self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
|
118 |
+
self.bn3 = nn.BatchNorm2d(width)
|
119 |
+
self.relu3 = nn.ReLU(inplace=True)
|
120 |
+
self.avgpool = nn.AvgPool2d(2)
|
121 |
+
|
122 |
+
# residual layers
|
123 |
+
self._inplanes = width # this is a *mutable* variable used during construction
|
124 |
+
self.layer1 = self._make_layer(width, layers[0])
|
125 |
+
self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
|
126 |
+
self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
|
127 |
+
self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
|
128 |
+
|
129 |
+
embed_dim = width * 32 # the ResNet feature dimension
|
130 |
+
self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
|
131 |
+
|
132 |
+
def _make_layer(self, planes, blocks, stride=1):
|
133 |
+
layers = [Bottleneck(self._inplanes, planes, stride)]
|
134 |
+
|
135 |
+
self._inplanes = planes * Bottleneck.expansion
|
136 |
+
for _ in range(1, blocks):
|
137 |
+
layers.append(Bottleneck(self._inplanes, planes))
|
138 |
+
|
139 |
+
return nn.Sequential(*layers)
|
140 |
+
|
141 |
+
def forward(self, x, alpha=None):
|
142 |
+
def stem(x):
|
143 |
+
x = self.relu1(self.bn1(self.conv1(x) + self.conv1_alpha(alpha)))
|
144 |
+
x = self.relu2(self.bn2(self.conv2(x)))
|
145 |
+
x = self.relu3(self.bn3(self.conv3(x)))
|
146 |
+
x = self.avgpool(x)
|
147 |
+
return x
|
148 |
+
|
149 |
+
x = x.type(self.conv1.weight.dtype)
|
150 |
+
x = stem(x)
|
151 |
+
x = self.layer1(x)
|
152 |
+
x = self.layer2(x)
|
153 |
+
x = self.layer3(x)
|
154 |
+
x = self.layer4(x)
|
155 |
+
x = self.attnpool(x)
|
156 |
+
|
157 |
+
return x
|
158 |
+
|
159 |
+
|
160 |
+
class LayerNorm(nn.LayerNorm):
|
161 |
+
"""Subclass torch's LayerNorm to handle fp16."""
|
162 |
+
|
163 |
+
def forward(self, x: torch.Tensor):
|
164 |
+
orig_type = x.dtype
|
165 |
+
ret = super().forward(x.type(torch.float32))
|
166 |
+
return ret.type(orig_type)
|
167 |
+
|
168 |
+
|
169 |
+
class QuickGELU(nn.Module):
|
170 |
+
def forward(self, x: torch.Tensor):
|
171 |
+
return x * torch.sigmoid(1.702 * x)
|
172 |
+
|
173 |
+
class Attention(nn.Module):
|
174 |
+
def __init__(
|
175 |
+
self,
|
176 |
+
dim,
|
177 |
+
num_heads=8,
|
178 |
+
qkv_bias=True,
|
179 |
+
scaled_cosine=False,
|
180 |
+
scale_heads=False,
|
181 |
+
logit_scale_max=math.log(1. / 0.01),
|
182 |
+
attn_drop=0.,
|
183 |
+
proj_drop=0.,
|
184 |
+
lora_adapt=False,
|
185 |
+
rank=16
|
186 |
+
):
|
187 |
+
super().__init__()
|
188 |
+
self.scaled_cosine = scaled_cosine
|
189 |
+
self.scale_heads = scale_heads
|
190 |
+
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
|
191 |
+
self.num_heads = num_heads
|
192 |
+
self.head_dim = dim // num_heads
|
193 |
+
self.scale = self.head_dim ** -0.5
|
194 |
+
self.logit_scale_max = logit_scale_max
|
195 |
+
|
196 |
+
# keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original
|
197 |
+
if lora_adapt:
|
198 |
+
print("!!!!!!!!!!using lora for qkv projection!!!!!!!!!!")
|
199 |
+
self.in_proj = lora.MergedLinear(dim, 3*dim, r=rank, enable_lora=[True, False, True])
|
200 |
+
else:
|
201 |
+
self.in_proj = nn.Linear(dim, dim * 3)
|
202 |
+
# self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
|
203 |
+
# if qkv_bias:
|
204 |
+
# self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
|
205 |
+
# else:
|
206 |
+
# self.in_proj_bias = None
|
207 |
+
|
208 |
+
if self.scaled_cosine:
|
209 |
+
self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
|
210 |
+
else:
|
211 |
+
self.logit_scale = None
|
212 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
213 |
+
if self.scale_heads:
|
214 |
+
self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
|
215 |
+
else:
|
216 |
+
self.head_scale = None
|
217 |
+
self.out_proj = nn.Linear(dim, dim) if not lora_adapt else lora.Linear(dim, dim, r=rank)
|
218 |
+
self.out_drop = nn.Dropout(proj_drop)
|
219 |
+
|
220 |
+
def forward(self, x, attn_mask = None):
|
221 |
+
L, N, C = x.shape
|
222 |
+
q, k, v = self.in_proj(x).chunk(3, dim=-1)
|
223 |
+
q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
|
224 |
+
k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
|
225 |
+
v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
|
226 |
+
|
227 |
+
if self.logit_scale is not None:
|
228 |
+
attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))
|
229 |
+
logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
|
230 |
+
attn = attn.view(N, self.num_heads, L, L) * logit_scale
|
231 |
+
attn = attn.view(-1, L, L)
|
232 |
+
else:
|
233 |
+
q = q * self.scale
|
234 |
+
attn = torch.bmm(q, k.transpose(-2, -1))
|
235 |
+
|
236 |
+
if attn_mask is not None:
|
237 |
+
if attn_mask.dtype == torch.bool:
|
238 |
+
new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
|
239 |
+
new_attn_mask.masked_fill_(attn_mask, float("-inf"))
|
240 |
+
attn_mask = new_attn_mask
|
241 |
+
attn += attn_mask
|
242 |
+
|
243 |
+
attn = attn.softmax(dim=-1)
|
244 |
+
attn = self.attn_drop(attn)
|
245 |
+
|
246 |
+
x = torch.bmm(attn, v)
|
247 |
+
if self.head_scale is not None:
|
248 |
+
x = x.view(N, self.num_heads, L, C) * self.head_scale
|
249 |
+
x = x.view(-1, L, C)
|
250 |
+
x = x.transpose(0, 1).reshape(L, N, C)
|
251 |
+
x = self.out_proj(x)
|
252 |
+
x = self.out_drop(x)
|
253 |
+
return x, attn
|
254 |
+
|
255 |
+
|
256 |
+
class CustomResidualAttentionBlock(nn.Module):
|
257 |
+
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, lora_adapt=False, rank=16):
|
258 |
+
super().__init__()
|
259 |
+
|
260 |
+
self.attn = Attention(d_model, n_head, lora_adapt=lora_adapt, rank=rank)
|
261 |
+
self.ln_1 = LayerNorm(d_model)
|
262 |
+
self.mlp = nn.Sequential(OrderedDict([
|
263 |
+
("c_fc", nn.Linear(d_model, d_model * 4) if not lora_adapt else lora.Linear(d_model, d_model*4, r=rank)),
|
264 |
+
("gelu", QuickGELU()),
|
265 |
+
("c_proj", nn.Linear(d_model * 4, d_model) if not lora_adapt else lora.Linear(d_model*4, d_model, r=rank))
|
266 |
+
]))
|
267 |
+
self.ln_2 = LayerNorm(d_model)
|
268 |
+
self.attn_mask = attn_mask
|
269 |
+
|
270 |
+
def attention(self, x: torch.Tensor):
|
271 |
+
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
|
272 |
+
return self.attn(x, attn_mask=self.attn_mask)
|
273 |
+
|
274 |
+
def forward(self, x: torch.Tensor, return_attn=False):
|
275 |
+
attn_out, attn = self.attention(self.ln_1(x))
|
276 |
+
x = x + attn_out
|
277 |
+
x = x + self.mlp(self.ln_2(x))
|
278 |
+
if return_attn:
|
279 |
+
return x, attn
|
280 |
+
else:
|
281 |
+
return x
|
282 |
+
|
283 |
+
class ResidualAttentionBlock(nn.Module):
|
284 |
+
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
|
285 |
+
super().__init__()
|
286 |
+
|
287 |
+
self.attn = nn.MultiheadAttention(d_model, n_head)
|
288 |
+
self.ln_1 = LayerNorm(d_model)
|
289 |
+
self.mlp = nn.Sequential(OrderedDict([
|
290 |
+
("c_fc", nn.Linear(d_model, d_model * 4)),
|
291 |
+
("gelu", QuickGELU()),
|
292 |
+
("c_proj", nn.Linear(d_model * 4, d_model))
|
293 |
+
]))
|
294 |
+
self.ln_2 = LayerNorm(d_model)
|
295 |
+
self.attn_mask = attn_mask
|
296 |
+
|
297 |
+
def attention(self, x: torch.Tensor):
|
298 |
+
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
|
299 |
+
return self.attn(x, x, x, attn_mask=self.attn_mask)[0]
|
300 |
+
|
301 |
+
def forward(self, x: torch.Tensor):
|
302 |
+
x = x + self.attention(self.ln_1(x))
|
303 |
+
x = x + self.mlp(self.ln_2(x))
|
304 |
+
return x
|
305 |
+
|
306 |
+
class Transformer(nn.Module):
|
307 |
+
def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
|
308 |
+
super().__init__()
|
309 |
+
self.width = width
|
310 |
+
self.layers = layers
|
311 |
+
self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
|
312 |
+
|
313 |
+
def forward(self, x: torch.Tensor):
|
314 |
+
return self.resblocks(x)
|
315 |
+
|
316 |
+
class CustomTransformer(nn.Module):
|
317 |
+
def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, lora_adapt=False, rank=16):
|
318 |
+
super().__init__()
|
319 |
+
self.width = width
|
320 |
+
self.layers = layers
|
321 |
+
self.resblocks = nn.Sequential(*[CustomResidualAttentionBlock(width, heads, attn_mask, lora_adapt=lora_adapt, rank=rank) for _ in range(layers)])
|
322 |
+
|
323 |
+
def forward(self, x: torch.Tensor, return_attn=False):
|
324 |
+
if return_attn:
|
325 |
+
for i, block in enumerate(self.resblocks):
|
326 |
+
if i == len(self.resblocks) - 1:
|
327 |
+
return block(x, return_attn=True)
|
328 |
+
else:
|
329 |
+
x = block(x)
|
330 |
+
assert False
|
331 |
+
return self.resblocks(x)
|
332 |
+
|
333 |
+
class VisionTransformer(nn.Module):
|
334 |
+
def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int, lora_adapt=False, rank=16):
|
335 |
+
super().__init__()
|
336 |
+
self.input_resolution = input_resolution
|
337 |
+
self.output_dim = output_dim
|
338 |
+
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
|
339 |
+
self.conv1_alpha = nn.Conv2d(in_channels=1, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
|
340 |
+
|
341 |
+
scale = width ** -0.5
|
342 |
+
self.class_embedding = nn.Parameter(scale * torch.randn(width))
|
343 |
+
self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
|
344 |
+
self.ln_pre = LayerNorm(width)
|
345 |
+
|
346 |
+
self.transformer = CustomTransformer(width, layers, heads, lora_adapt=lora_adapt, rank=rank)
|
347 |
+
|
348 |
+
self.ln_post = LayerNorm(width)
|
349 |
+
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
|
350 |
+
|
351 |
+
def forward(self, x: torch.Tensor, alpha=None, return_attn=False):
|
352 |
+
x = self.conv1(x) # shape = [*, width, grid, grid]
|
353 |
+
# ASSUME alpha is always not None!
|
354 |
+
x = x + self.conv1_alpha(alpha)
|
355 |
+
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
|
356 |
+
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
|
357 |
+
x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
|
358 |
+
x = x + self.positional_embedding.to(x.dtype)
|
359 |
+
x = self.ln_pre(x)
|
360 |
+
|
361 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
362 |
+
if return_attn:
|
363 |
+
x, attn_last = self.transformer(x, return_attn=True)
|
364 |
+
else:
|
365 |
+
x = self.transformer(x, return_attn=False)
|
366 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
367 |
+
|
368 |
+
x = self.ln_post(x[:, 0, :])
|
369 |
+
|
370 |
+
if self.proj is not None:
|
371 |
+
x = x @ self.proj
|
372 |
+
if return_attn:
|
373 |
+
return x, attn_last
|
374 |
+
else:
|
375 |
+
return x
|
376 |
+
|
377 |
+
|
378 |
+
class CLIP(nn.Module):
|
379 |
+
def __init__(self,
|
380 |
+
embed_dim: int,
|
381 |
+
# vision
|
382 |
+
image_resolution: int,
|
383 |
+
vision_layers: Union[Tuple[int, int, int, int], int],
|
384 |
+
vision_width: int,
|
385 |
+
vision_patch_size: int,
|
386 |
+
# text
|
387 |
+
context_length: int,
|
388 |
+
vocab_size: int,
|
389 |
+
transformer_width: int,
|
390 |
+
transformer_heads: int,
|
391 |
+
transformer_layers: int,
|
392 |
+
lora_adapt = False,
|
393 |
+
rank = 16,
|
394 |
+
):
|
395 |
+
super().__init__()
|
396 |
+
|
397 |
+
self.context_length = context_length
|
398 |
+
|
399 |
+
if isinstance(vision_layers, (tuple, list)):
|
400 |
+
vision_heads = vision_width * 32 // 64
|
401 |
+
self.visual = ModifiedResNet(
|
402 |
+
layers=vision_layers,
|
403 |
+
output_dim=embed_dim,
|
404 |
+
heads=vision_heads,
|
405 |
+
input_resolution=image_resolution,
|
406 |
+
width=vision_width
|
407 |
+
)
|
408 |
+
else:
|
409 |
+
vision_heads = vision_width // 64
|
410 |
+
self.visual = VisionTransformer(
|
411 |
+
input_resolution=image_resolution,
|
412 |
+
patch_size=vision_patch_size,
|
413 |
+
width=vision_width,
|
414 |
+
layers=vision_layers,
|
415 |
+
heads=vision_heads,
|
416 |
+
output_dim=embed_dim,
|
417 |
+
lora_adapt=lora_adapt,
|
418 |
+
rank=rank
|
419 |
+
)
|
420 |
+
|
421 |
+
self.transformer = Transformer(
|
422 |
+
width=transformer_width,
|
423 |
+
layers=transformer_layers,
|
424 |
+
heads=transformer_heads,
|
425 |
+
attn_mask=self.build_attention_mask()
|
426 |
+
)
|
427 |
+
|
428 |
+
self.vocab_size = vocab_size
|
429 |
+
self.token_embedding = nn.Embedding(vocab_size, transformer_width)
|
430 |
+
self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
|
431 |
+
self.ln_final = LayerNorm(transformer_width)
|
432 |
+
|
433 |
+
self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
|
434 |
+
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
435 |
+
|
436 |
+
self.initialize_parameters()
|
437 |
+
|
438 |
+
def initialize_parameters(self):
|
439 |
+
nn.init.normal_(self.token_embedding.weight, std=0.02)
|
440 |
+
nn.init.normal_(self.positional_embedding, std=0.01)
|
441 |
+
|
442 |
+
if isinstance(self.visual, ModifiedResNet):
|
443 |
+
if self.visual.attnpool is not None:
|
444 |
+
std = self.visual.attnpool.c_proj.in_features ** -0.5
|
445 |
+
nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
|
446 |
+
nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
|
447 |
+
nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
|
448 |
+
nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
|
449 |
+
|
450 |
+
for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
|
451 |
+
for name, param in resnet_block.named_parameters():
|
452 |
+
if name.endswith("bn3.weight"):
|
453 |
+
nn.init.zeros_(param)
|
454 |
+
|
455 |
+
proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
|
456 |
+
attn_std = self.transformer.width ** -0.5
|
457 |
+
fc_std = (2 * self.transformer.width) ** -0.5
|
458 |
+
for block in self.transformer.resblocks:
|
459 |
+
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
|
460 |
+
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
|
461 |
+
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
|
462 |
+
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
|
463 |
+
|
464 |
+
if self.text_projection is not None:
|
465 |
+
nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
|
466 |
+
|
467 |
+
def build_attention_mask(self):
|
468 |
+
# lazily create causal attention mask, with full attention between the vision tokens
|
469 |
+
# pytorch uses additive attention mask; fill with -inf
|
470 |
+
mask = torch.empty(self.context_length, self.context_length)
|
471 |
+
mask.fill_(float("-inf"))
|
472 |
+
mask.triu_(1) # zero out the lower diagonal
|
473 |
+
return mask
|
474 |
+
|
475 |
+
@property
|
476 |
+
def dtype(self):
|
477 |
+
if not hasattr(self.visual, "conv1"):
|
478 |
+
return self.visual.module.conv1.weight.dtype
|
479 |
+
return self.visual.conv1.weight.dtype
|
480 |
+
|
481 |
+
def encode_image(self, image, alpha):
|
482 |
+
assert alpha is not None
|
483 |
+
return self.visual(image.type(self.dtype), alpha.type(self.dtype))
|
484 |
+
|
485 |
+
def encode_text(self, text):
|
486 |
+
x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
|
487 |
+
|
488 |
+
x = x + self.positional_embedding.type(self.dtype)
|
489 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
490 |
+
x = self.transformer(x)
|
491 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
492 |
+
x = self.ln_final(x).type(self.dtype)
|
493 |
+
|
494 |
+
# x.shape = [batch_size, n_ctx, transformer.width]
|
495 |
+
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
496 |
+
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
|
497 |
+
|
498 |
+
return x
|
499 |
+
|
500 |
+
def forward(self, image, text, alpha):
|
501 |
+
image_features = self.encode_image(image, alpha)
|
502 |
+
text_features = self.encode_text(text)
|
503 |
+
|
504 |
+
# normalized features
|
505 |
+
image_features = image_features / image_features.norm(dim=1, keepdim=True)
|
506 |
+
text_features = text_features / text_features.norm(dim=1, keepdim=True)
|
507 |
+
|
508 |
+
# cosine similarity as logits
|
509 |
+
logit_scale = self.logit_scale.exp()
|
510 |
+
logits_per_image = logit_scale * image_features @ text_features.t()
|
511 |
+
logits_per_text = logits_per_image.t()
|
512 |
+
|
513 |
+
# shape = [global_batch_size, global_batch_size]
|
514 |
+
return logits_per_image, logits_per_text
|
515 |
+
|
516 |
+
|
517 |
+
def convert_weights(model: nn.Module):
|
518 |
+
"""Convert applicable model parameters to fp16"""
|
519 |
+
|
520 |
+
def _convert_weights_to_fp16(l):
|
521 |
+
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
|
522 |
+
l.weight.data = l.weight.data.half()
|
523 |
+
if l.bias is not None:
|
524 |
+
l.bias.data = l.bias.data.half()
|
525 |
+
|
526 |
+
if isinstance(l, nn.MultiheadAttention):
|
527 |
+
for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
|
528 |
+
tensor = getattr(l, attr)
|
529 |
+
if tensor is not None:
|
530 |
+
tensor.data = tensor.data.half()
|
531 |
+
|
532 |
+
for name in ["text_projection", "proj"]:
|
533 |
+
if hasattr(l, name):
|
534 |
+
attr = getattr(l, name)
|
535 |
+
if attr is not None:
|
536 |
+
attr.data = attr.data.half()
|
537 |
+
|
538 |
+
model.apply(_convert_weights_to_fp16)
|
539 |
+
|
540 |
+
|
541 |
+
def build_model(state_dict: dict, lora_adapt=False, rank=16):
|
542 |
+
vit = "visual.proj" in state_dict
|
543 |
+
|
544 |
+
if vit:
|
545 |
+
vision_width = state_dict["visual.conv1.weight"].shape[0]
|
546 |
+
vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
|
547 |
+
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
|
548 |
+
grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
|
549 |
+
image_resolution = vision_patch_size * grid_size
|
550 |
+
else:
|
551 |
+
counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
|
552 |
+
vision_layers = tuple(counts)
|
553 |
+
vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
|
554 |
+
output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
|
555 |
+
vision_patch_size = None
|
556 |
+
assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
|
557 |
+
image_resolution = output_width * 32
|
558 |
+
|
559 |
+
embed_dim = state_dict["text_projection"].shape[1]
|
560 |
+
context_length = state_dict["positional_embedding"].shape[0]
|
561 |
+
vocab_size = state_dict["token_embedding.weight"].shape[0]
|
562 |
+
transformer_width = state_dict["ln_final.weight"].shape[0]
|
563 |
+
transformer_heads = transformer_width // 64
|
564 |
+
transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
|
565 |
+
|
566 |
+
# always load lora version
|
567 |
+
model = CLIP(
|
568 |
+
embed_dim,
|
569 |
+
image_resolution, vision_layers, vision_width, vision_patch_size,
|
570 |
+
context_length, vocab_size, transformer_width, transformer_heads, transformer_layers,
|
571 |
+
lora_adapt=lora_adapt, rank=rank,
|
572 |
+
)
|
573 |
+
|
574 |
+
for key in ["input_resolution", "context_length", "vocab_size"]:
|
575 |
+
if key in state_dict:
|
576 |
+
del state_dict[key]
|
577 |
+
# para_wb to linear
|
578 |
+
new_state_dict = collections.OrderedDict()
|
579 |
+
for k, v in state_dict.items():
|
580 |
+
if 'visual' in k:
|
581 |
+
if 'in_proj_weight' in k:
|
582 |
+
new_state_dict[k.replace('in_proj_weight', 'in_proj.weight')] = v
|
583 |
+
elif 'in_proj_bias' in k:
|
584 |
+
new_state_dict[k.replace('in_proj_bias', 'in_proj.bias')] = v
|
585 |
+
else:
|
586 |
+
new_state_dict[k] = v
|
587 |
+
else:
|
588 |
+
new_state_dict[k] = v
|
589 |
+
|
590 |
+
state_dict = new_state_dict
|
591 |
+
# add rgba_conv_weight
|
592 |
+
if 'visual.conv1_alpha.weight' not in state_dict.keys(): # zero initialization on alpha channel
|
593 |
+
rgb_weight = state_dict['visual.conv1.weight'].clone().detach()
|
594 |
+
rgba_weigth = torch.zeros_like(rgb_weight)[:, 0:1, :, :]
|
595 |
+
state_dict['visual.conv1_alpha.weight'] = rgba_weigth
|
596 |
+
convert_weights(model)
|
597 |
+
model.load_state_dict(state_dict, strict=False)
|
598 |
+
return model.eval()
|
AlphaCLIP/alpha_clip/simple_tokenizer.py
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gzip
|
2 |
+
import html
|
3 |
+
import os
|
4 |
+
from functools import lru_cache
|
5 |
+
|
6 |
+
import ftfy
|
7 |
+
import regex as re
|
8 |
+
|
9 |
+
|
10 |
+
@lru_cache()
|
11 |
+
def default_bpe():
|
12 |
+
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
|
13 |
+
|
14 |
+
|
15 |
+
@lru_cache()
|
16 |
+
def bytes_to_unicode():
|
17 |
+
"""
|
18 |
+
Returns list of utf-8 byte and a corresponding list of unicode strings.
|
19 |
+
The reversible bpe codes work on unicode strings.
|
20 |
+
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
|
21 |
+
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
|
22 |
+
This is a signficant percentage of your normal, say, 32K bpe vocab.
|
23 |
+
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
|
24 |
+
And avoids mapping to whitespace/control characters the bpe code barfs on.
|
25 |
+
"""
|
26 |
+
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
|
27 |
+
cs = bs[:]
|
28 |
+
n = 0
|
29 |
+
for b in range(2**8):
|
30 |
+
if b not in bs:
|
31 |
+
bs.append(b)
|
32 |
+
cs.append(2**8+n)
|
33 |
+
n += 1
|
34 |
+
cs = [chr(n) for n in cs]
|
35 |
+
return dict(zip(bs, cs))
|
36 |
+
|
37 |
+
|
38 |
+
def get_pairs(word):
|
39 |
+
"""Return set of symbol pairs in a word.
|
40 |
+
Word is represented as tuple of symbols (symbols being variable-length strings).
|
41 |
+
"""
|
42 |
+
pairs = set()
|
43 |
+
prev_char = word[0]
|
44 |
+
for char in word[1:]:
|
45 |
+
pairs.add((prev_char, char))
|
46 |
+
prev_char = char
|
47 |
+
return pairs
|
48 |
+
|
49 |
+
|
50 |
+
def basic_clean(text):
|
51 |
+
text = ftfy.fix_text(text)
|
52 |
+
text = html.unescape(html.unescape(text))
|
53 |
+
return text.strip()
|
54 |
+
|
55 |
+
|
56 |
+
def whitespace_clean(text):
|
57 |
+
text = re.sub(r'\s+', ' ', text)
|
58 |
+
text = text.strip()
|
59 |
+
return text
|
60 |
+
|
61 |
+
|
62 |
+
class SimpleTokenizer(object):
|
63 |
+
def __init__(self, bpe_path: str = default_bpe()):
|
64 |
+
self.byte_encoder = bytes_to_unicode()
|
65 |
+
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
|
66 |
+
merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
|
67 |
+
merges = merges[1:49152-256-2+1]
|
68 |
+
merges = [tuple(merge.split()) for merge in merges]
|
69 |
+
vocab = list(bytes_to_unicode().values())
|
70 |
+
vocab = vocab + [v+'</w>' for v in vocab]
|
71 |
+
for merge in merges:
|
72 |
+
vocab.append(''.join(merge))
|
73 |
+
vocab.extend(['<|startoftext|>', '<|endoftext|>'])
|
74 |
+
self.encoder = dict(zip(vocab, range(len(vocab))))
|
75 |
+
self.decoder = {v: k for k, v in self.encoder.items()}
|
76 |
+
self.bpe_ranks = dict(zip(merges, range(len(merges))))
|
77 |
+
self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
|
78 |
+
self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
|
79 |
+
|
80 |
+
def bpe(self, token):
|
81 |
+
if token in self.cache:
|
82 |
+
return self.cache[token]
|
83 |
+
word = tuple(token[:-1]) + ( token[-1] + '</w>',)
|
84 |
+
pairs = get_pairs(word)
|
85 |
+
|
86 |
+
if not pairs:
|
87 |
+
return token+'</w>'
|
88 |
+
|
89 |
+
while True:
|
90 |
+
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
|
91 |
+
if bigram not in self.bpe_ranks:
|
92 |
+
break
|
93 |
+
first, second = bigram
|
94 |
+
new_word = []
|
95 |
+
i = 0
|
96 |
+
while i < len(word):
|
97 |
+
try:
|
98 |
+
j = word.index(first, i)
|
99 |
+
new_word.extend(word[i:j])
|
100 |
+
i = j
|
101 |
+
except:
|
102 |
+
new_word.extend(word[i:])
|
103 |
+
break
|
104 |
+
|
105 |
+
if word[i] == first and i < len(word)-1 and word[i+1] == second:
|
106 |
+
new_word.append(first+second)
|
107 |
+
i += 2
|
108 |
+
else:
|
109 |
+
new_word.append(word[i])
|
110 |
+
i += 1
|
111 |
+
new_word = tuple(new_word)
|
112 |
+
word = new_word
|
113 |
+
if len(word) == 1:
|
114 |
+
break
|
115 |
+
else:
|
116 |
+
pairs = get_pairs(word)
|
117 |
+
word = ' '.join(word)
|
118 |
+
self.cache[token] = word
|
119 |
+
return word
|
120 |
+
|
121 |
+
def encode(self, text):
|
122 |
+
bpe_tokens = []
|
123 |
+
text = whitespace_clean(basic_clean(text)).lower()
|
124 |
+
for token in re.findall(self.pat, text):
|
125 |
+
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
|
126 |
+
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
|
127 |
+
return bpe_tokens
|
128 |
+
|
129 |
+
def decode(self, tokens):
|
130 |
+
text = ''.join([self.decoder[token] for token in tokens])
|
131 |
+
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
|
132 |
+
return text
|
AlphaCLIP/eval/README.md
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Alpha-CLIP evaluation
|
2 |
+
## Zero-Shot Classification on ImageNet-S
|
3 |
+
checkout [imagenet_s_zs_test](https://github.com/SunzeY/AlphaCLIP/tree/eval-dev/eval/imagenet_s_zs_test)
|
4 |
+
|
5 |
+
## Zero-Shot Referring Expression Comprehension on RefCOCO
|
6 |
+
checkout [rec_zs_test](https://github.com/SunzeY/AlphaCLIP/tree/eval-dev/eval/rec_zs_test)
|
AlphaCLIP/eval/imagenet_s_zs_test/.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
*.json
|
2 |
+
data/*
|
AlphaCLIP/eval/imagenet_s_zs_test/README.md
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Alpha-CLIP evaluation
|
2 |
+
## Zero-Shot Classification on ImageNet-S
|
3 |
+
|
4 |
+
1.prepare [imagenet-s](https://github.com/LUSSeg/ImageNet-S) dataset, only `validation` raw image is needed.
|
5 |
+
|
6 |
+
2.download [imagenet_919.json](https://download.openxlab.org.cn/models/SunzeY/AlphaCLIP/weight/imagenet_919.json) we provide as data annotation (generated from imagenet-s annotation). The folder should be structured like
|
7 |
+
|
8 |
+
```
|
9 |
+
├── imagenet_s_zs_test
|
10 |
+
│ ├── data
|
11 |
+
│ │ ├── imagenet_919.json
|
12 |
+
│ │ └── ImageNetS919
|
13 |
+
│ │ └── validation
|
14 |
+
```
|
15 |
+
|
16 |
+
3.run test script.
|
17 |
+
|
18 |
+
```
|
19 |
+
cd eval/imagenet_s_zs_test
|
20 |
+
python imagenet_s_zs_test.py
|
21 |
+
```
|
AlphaCLIP/eval/imagenet_s_zs_test/imagenet_s.py
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
import random
|
4 |
+
from tqdm import tqdm
|
5 |
+
from torch.utils.data import Dataset
|
6 |
+
from pycocotools.coco import COCO
|
7 |
+
from pycocotools import mask as maskUtils
|
8 |
+
from PIL import Image
|
9 |
+
import cv2
|
10 |
+
import random
|
11 |
+
from torchvision import transforms
|
12 |
+
from tqdm import tqdm
|
13 |
+
|
14 |
+
import pickle
|
15 |
+
import torch
|
16 |
+
import numpy as np
|
17 |
+
import copy
|
18 |
+
import sys
|
19 |
+
import shutil
|
20 |
+
from PIL import Image
|
21 |
+
from nltk.corpus import wordnet
|
22 |
+
|
23 |
+
PIXEL_MEAN = (0.48145466, 0.4578275, 0.40821073)
|
24 |
+
MASK_FILL = [int(255 * c) for c in PIXEL_MEAN]
|
25 |
+
|
26 |
+
|
27 |
+
clip_standard_transform = transforms.Compose([
|
28 |
+
transforms.ToTensor(),
|
29 |
+
transforms.Resize((224, 224), interpolation=Image.BICUBIC),
|
30 |
+
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
|
31 |
+
])
|
32 |
+
|
33 |
+
hi_clip_standard_transform = transforms.Compose([
|
34 |
+
transforms.ToTensor(),
|
35 |
+
transforms.Resize((336, 336), interpolation=Image.BICUBIC),
|
36 |
+
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
|
37 |
+
])
|
38 |
+
|
39 |
+
res_clip_standard_transform = transforms.Compose([
|
40 |
+
transforms.ToTensor(),
|
41 |
+
transforms.Resize((336, 336), interpolation=Image.BICUBIC),
|
42 |
+
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
|
43 |
+
])
|
44 |
+
|
45 |
+
mask_transform = transforms.Compose([
|
46 |
+
transforms.ToTensor(),
|
47 |
+
transforms.Resize((224, 224)),
|
48 |
+
transforms.Normalize(0.5, 0.26)
|
49 |
+
])
|
50 |
+
|
51 |
+
hi_mask_transform = transforms.Compose([
|
52 |
+
transforms.ToTensor(),
|
53 |
+
transforms.Resize((336, 336)),
|
54 |
+
transforms.Normalize(0.5, 0.26)
|
55 |
+
])
|
56 |
+
|
57 |
+
res_mask_transform = transforms.Compose([
|
58 |
+
transforms.ToTensor(),
|
59 |
+
transforms.Resize((336, 336)),
|
60 |
+
transforms.Normalize(0.5, 0.26)
|
61 |
+
])
|
62 |
+
|
63 |
+
def crop_center(img, croph, cropw):
|
64 |
+
h, w = img.shape[:2]
|
65 |
+
starth = h//2 - (croph//2)
|
66 |
+
startw = w//2 - (cropw//2)
|
67 |
+
return img[starth:starth+croph, startw:startw+cropw, :]
|
68 |
+
|
69 |
+
class Imagenet_S(Dataset):
|
70 |
+
def __init__(self, ann_file='data/imagenet_919.json', hi_res=False, all_one=False):
|
71 |
+
self.anns = json.load(open(ann_file, 'r'))
|
72 |
+
self.root_pth = 'data/'
|
73 |
+
cats = []
|
74 |
+
for ann in self.anns:
|
75 |
+
if ann['category_word'] not in cats:
|
76 |
+
cats.append(ann['category_word'])
|
77 |
+
ann['cat_index'] = len(cats) - 1
|
78 |
+
self.classes = []
|
79 |
+
for cat_word in cats:
|
80 |
+
synset = wordnet.synset_from_pos_and_offset('n', int(cat_word[1:]))
|
81 |
+
synonyms = [x.name() for x in synset.lemmas()]
|
82 |
+
self.classes.append(synonyms[0])
|
83 |
+
|
84 |
+
self.choice = "center_crop"
|
85 |
+
if hi_res:
|
86 |
+
self.mask_transform = res_mask_transform
|
87 |
+
self.clip_standard_transform = res_clip_standard_transform
|
88 |
+
else:
|
89 |
+
self.mask_transform = mask_transform
|
90 |
+
self.clip_standard_transform = clip_standard_transform
|
91 |
+
|
92 |
+
self.all_one = all_one
|
93 |
+
|
94 |
+
def __len__(self):
|
95 |
+
return len(self.anns)
|
96 |
+
|
97 |
+
def __getitem__(self, index):
|
98 |
+
ann = self.anns[index]
|
99 |
+
image = cv2.imread(os.path.join(self.root_pth, ann['image_pth']))
|
100 |
+
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
101 |
+
|
102 |
+
mask = maskUtils.decode(ann['mask'])
|
103 |
+
# image[mask==0] = MASK_FILL
|
104 |
+
rgba = np.concatenate((image, np.expand_dims(mask, axis=-1)), axis=-1)
|
105 |
+
h, w = rgba.shape[:2]
|
106 |
+
|
107 |
+
if self.choice == "padding":
|
108 |
+
if max(h, w) == w:
|
109 |
+
pad = (w - h) // 2
|
110 |
+
l, r = pad, w - h - pad
|
111 |
+
rgba = np.pad(rgba, ((l, r), (0, 0), (0, 0)), 'constant', constant_values=0)
|
112 |
+
else:
|
113 |
+
pad = (h - w) // 2
|
114 |
+
l, r = pad, h - w - pad
|
115 |
+
rgba = np.pad(rgba, ((0, 0), (l, r), (0, 0)), 'constant', constant_values=0)
|
116 |
+
else:
|
117 |
+
if min(h, w) == h:
|
118 |
+
rgba = crop_center(rgba, h, h)
|
119 |
+
else:
|
120 |
+
rgba = crop_center(rgba, w, w)
|
121 |
+
rgb = rgba[:, :, :-1]
|
122 |
+
mask = rgba[:, :, -1]
|
123 |
+
image_torch = self.clip_standard_transform(rgb)
|
124 |
+
# using box: bounding-box compute
|
125 |
+
# bi_mask = mask == 1
|
126 |
+
# h, w = bi_mask.shape[-2:]
|
127 |
+
# in_height = np.max(bi_mask, axis=-1)
|
128 |
+
# in_height_coords = np.max(bi_mask, axis=-1) * np.arange(h)
|
129 |
+
# b_e = in_height_coords.max()
|
130 |
+
# in_height_coords = in_height_coords + h * (~in_height)
|
131 |
+
# t_e = in_height_coords.min()
|
132 |
+
# in_width = np.max(bi_mask, axis=-2)
|
133 |
+
# in_width_coords = np.max(bi_mask, axis=-2) * np.arange(w)
|
134 |
+
# r_e = in_width_coords.max()
|
135 |
+
# in_width_coords = in_width_coords + w * (~in_width)
|
136 |
+
# l_e = in_width_coords.min()
|
137 |
+
# box = np.zeros_like(mask)
|
138 |
+
# box[t_e: b_e, l_e:r_e] = 1
|
139 |
+
# mask = box
|
140 |
+
if self.all_one:
|
141 |
+
mask_torch = self.mask_transform(np.ones_like(mask) * 255)
|
142 |
+
else:
|
143 |
+
mask_torch = self.mask_transform(mask * 255)
|
144 |
+
return image_torch, mask_torch, ann['cat_index']
|
145 |
+
|
146 |
+
if __name__ == "__main__":
|
147 |
+
data = Imagenet_S()
|
148 |
+
for i in tqdm(range(data.__len__())):
|
149 |
+
data.__getitem__(i)
|
AlphaCLIP/eval/imagenet_s_zs_test/imagenet_s_zs_test.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import alpha_clip
|
3 |
+
from tqdm import tqdm
|
4 |
+
from imagenet_s import Imagenet_S
|
5 |
+
|
6 |
+
model, preprocess = alpha_clip.load("ViT-L/14@336px", alpha_vision_ckpt_pth="../../clip_l14@336_grit_20m_4xe.pth")
|
7 |
+
|
8 |
+
def zeroshot_classifier(classnames, templates):
|
9 |
+
with torch.no_grad():
|
10 |
+
zeroshot_weights = []
|
11 |
+
for classname in tqdm(classnames):
|
12 |
+
texts = [template.format(classname) for template in templates] #format with class
|
13 |
+
texts = alpha_clip.tokenize(texts).cuda() #tokenize
|
14 |
+
class_embeddings = model.encode_text(texts) #embed with text encoder
|
15 |
+
class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
|
16 |
+
class_embedding = class_embeddings.mean(dim=0)
|
17 |
+
class_embedding /= class_embedding.norm()
|
18 |
+
zeroshot_weights.append(class_embedding)
|
19 |
+
zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda()
|
20 |
+
return zeroshot_weights
|
21 |
+
|
22 |
+
dataset = Imagenet_S(hi_res=True)
|
23 |
+
loader = torch.utils.data.DataLoader(dataset, batch_size=64, num_workers=2)
|
24 |
+
|
25 |
+
imagenet_templates = [
|
26 |
+
'a photo of a {}.'
|
27 |
+
]
|
28 |
+
|
29 |
+
zeroshot_weights = zeroshot_classifier(dataset.classes, imagenet_templates)
|
30 |
+
temp_corr_dict = dict()
|
31 |
+
|
32 |
+
with torch.no_grad():
|
33 |
+
for i, (images, alpha, target) in enumerate(tqdm(loader)):
|
34 |
+
images = images.cuda()
|
35 |
+
alpha = alpha.cuda()
|
36 |
+
target = target.cuda()
|
37 |
+
# predict
|
38 |
+
image_features = model.encode_image(images, alpha)
|
39 |
+
image_features /= image_features.norm(dim=-1, keepdim=True)
|
40 |
+
score = 100. * image_features @ zeroshot_weights
|
41 |
+
|
42 |
+
pred = score.topk(1, dim=1)[1].squeeze(dim=1)
|
43 |
+
pred_5 = score.topk(5, dim=1)[1].squeeze(dim=1)
|
44 |
+
|
45 |
+
for i in range(target.shape[0]):
|
46 |
+
if target[i].item() not in temp_corr_dict:
|
47 |
+
temp_corr_dict[target[i].item()] = [0, 0, 0]
|
48 |
+
temp_corr_dict[target[i].item()][0] += 1
|
49 |
+
if target[i].item() == pred[i].item():
|
50 |
+
temp_corr_dict[target[i].item()][1] += 1
|
51 |
+
if target[i].item() in pred_5[i].tolist():
|
52 |
+
temp_corr_dict[target[i].item()][2] += 1
|
53 |
+
|
54 |
+
acc1 = 0.0
|
55 |
+
acc5 = 0.0
|
56 |
+
num_class = 0
|
57 |
+
for v in temp_corr_dict.values():
|
58 |
+
if v[0] == 0: continue
|
59 |
+
acc1 += v[1] / v[0]
|
60 |
+
acc5 += v[2] / v[0]
|
61 |
+
num_class += 1
|
62 |
+
acc1 = acc1 / num_class * 100
|
63 |
+
acc5 = acc5 / num_class * 100
|
64 |
+
|
65 |
+
print(f"Top-1 accuracy: {acc1:.2f}")
|
66 |
+
print(f"Top-5 accuracy: {acc5:.2f}")
|
AlphaCLIP/eval/rec_zs_test/LICENSE.md
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [yyyy] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
AlphaCLIP/eval/rec_zs_test/README.md
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Zero-Shot Referring Expression Comprehension on RefCOCO
|
2 |
+
|
3 |
+
**Preparing Data**
|
4 |
+
|
5 |
+
1.Download [images for RefCOCO/g/+](http://images.cocodataset.org/zips/train2014.zip). Put downloaded dataset(train2014) to eval/rec_zs_test/data/.
|
6 |
+
|
7 |
+
2.Download preprocessed data files via `gsutil cp gs://reclip-sanjays/reclip_data.tar.gz` and `cd rec_zs_test`, and then extract the data using `tar -xvzf reclip_data.tar.gz`.
|
8 |
+
|
9 |
+
**Preparing model**
|
10 |
+
|
11 |
+
3.Download [SAM](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth) (vit-h), [Alpha-CLIP](https://github.com/SunzeY/AlphaCLIP/blob/main/model-zoo.md) model, and put them in ./eval/rec_zs_test/ckpt.
|
12 |
+
|
13 |
+
```
|
14 |
+
├── eval
|
15 |
+
│ ├── rec_zs_test
|
16 |
+
│ │ ├── data
|
17 |
+
│ │ └── train2014
|
18 |
+
│ │ ├── reclip_data
|
19 |
+
│ │ └── refcoco_val.jsonl
|
20 |
+
│ │ └── refcoco_dets_dict.json
|
21 |
+
│ │ ...
|
22 |
+
│ │ ├── ckpt
|
23 |
+
│ │ └── sam_vit_h_4b8939.pth
|
24 |
+
│ │ └── grit1m
|
25 |
+
│ │ └── clip_b16_grit+mim_fultune_4xe.pth
|
26 |
+
│ │ └── clip_l14_grit+mim_fultune_6xe.pth
|
27 |
+
│ │ ├── methods
|
28 |
+
│ │ ├── cache
|
29 |
+
│ │ ├── output
|
30 |
+
│ │ ├── main.py
|
31 |
+
│ │ ├── executor.py
|
32 |
+
│ │ ├── run.sh
|
33 |
+
│ │ ├── ...
|
34 |
+
```
|
35 |
+
|
36 |
+
4.run test script.
|
37 |
+
|
38 |
+
```
|
39 |
+
cd eval/rec_zs_test
|
40 |
+
```
|
41 |
+
```
|
42 |
+
bash run.sh
|
43 |
+
```
|
44 |
+
or
|
45 |
+
|
46 |
+
```
|
47 |
+
python main.py --input_file reclip_data/refcoco_val.jsonl --image_root ./data/train2014 --method parse --gradcam_alpha 0.5 0.5 --box_representation_method full,blur --box_method_aggregator sum --clip_model ViT-B/16,ViT-L/14 --detector_file reclip_data/refcoco+_dets_dict.json --cache_path ./cache
|
48 |
+
```
|
49 |
+
(We recommend using `cache_path` to reduce time to generate mask by SAM for a image repeatedly.`)
|
50 |
+
|
51 |
+
For multi-gpus testing, try:
|
52 |
+
|
53 |
+
```
|
54 |
+
bash run_multi_gpus.sh
|
55 |
+
python cal_acc.py refcoco_val
|
56 |
+
```
|
57 |
+
|
58 |
+
|
59 |
+
**Acknowledgement**
|
60 |
+
|
61 |
+
We test our model based on the wonderful work [ReCLIP](https://github.com/allenai/reclip/tree/main). We simply replace CLIP with Alpha-CLIP; and skip the image-cropping operation.
|
62 |
+
|
63 |
+
|
64 |
+
|
65 |
+
**Experiment results**
|
66 |
+
|
67 |
+
| Method | RefCOCO | | | RefCOCO+ | | | RefCOCOg | |
|
68 |
+
|----------------|---------|------|------|----------|------|------|----------|------|
|
69 |
+
| | Val | TestA| TestB| Val | TestA| TestB| Val | Test |
|
70 |
+
| CPT [67] | 32.2 | 36.1 | 30.3 | 31.9 | 35.2 | 28.8 | 36.7 | 36.5 |
|
71 |
+
| ReCLIP [54] | 45.8 | 46.1 | 47.1 | 47.9 | 50.1 | 45.1 | 59.3 | 59.0 |
|
72 |
+
| Red Circle [52]| 49.8 | 58.6 | 39.9 | 55.3 | 63.9 | 45.4 | 59.4 | 58.9 |
|
73 |
+
| Alpha-CLIP | 55.7 | 61.1 | 50.3 | 55.6 | 62.7 | 46.4 | 61.2 | 62.0 |
|
74 |
+
|
AlphaCLIP/eval/rec_zs_test/cache/.gitkeep
ADDED
File without changes
|
AlphaCLIP/eval/rec_zs_test/cal_acc.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import argparse
|
3 |
+
|
4 |
+
parser = argparse.ArgumentParser()
|
5 |
+
parser.add_argument('name', type=str, default='refcoco_val')
|
6 |
+
|
7 |
+
args = parser.parse_args()
|
8 |
+
|
9 |
+
name = args.name
|
10 |
+
print(name)
|
11 |
+
count = 0
|
12 |
+
all_count = 0
|
13 |
+
for i in range(8):
|
14 |
+
pth = f'output/{name}_count_{i}.json'
|
15 |
+
acc = json.load(open(pth, 'r'))
|
16 |
+
a_list = acc.split()
|
17 |
+
a, b = a_list[0], a_list[1]
|
18 |
+
count += int(a)
|
19 |
+
all_count += int(b)
|
20 |
+
|
21 |
+
print(float(count) / float(all_count))
|
AlphaCLIP/eval/rec_zs_test/ckpt/.gitkeep
ADDED
File without changes
|
AlphaCLIP/eval/rec_zs_test/data/.gitkeep
ADDED
File without changes
|
AlphaCLIP/eval/rec_zs_test/entity_extraction.py
ADDED
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, Any, Callable, List, Tuple, NamedTuple, Text, Optional
|
2 |
+
import numpy as np
|
3 |
+
from spacy.tokens.token import Token
|
4 |
+
from spacy.tokens.span import Span
|
5 |
+
|
6 |
+
from lattice import Product as L
|
7 |
+
|
8 |
+
from heuristics import Heuristics
|
9 |
+
|
10 |
+
Rel = Tuple[List[Token], "Entity"]
|
11 |
+
Sup = List[Token]
|
12 |
+
|
13 |
+
DEFAULT_HEURISTICS = Heuristics()
|
14 |
+
|
15 |
+
|
16 |
+
def find_superlatives(tokens, heuristics) -> List[Sup]:
|
17 |
+
"""Modify and return a list of superlative tokens."""
|
18 |
+
for heuristic in heuristics.superlatives:
|
19 |
+
if any(tok.text in heuristic.keywords for tok in tokens):
|
20 |
+
tokens.sort(key=lambda tok: tok.i)
|
21 |
+
return [tokens]
|
22 |
+
return []
|
23 |
+
|
24 |
+
def expand_chunks(doc, chunks):
|
25 |
+
expanded = {}
|
26 |
+
for key in chunks:
|
27 |
+
chunk = chunks[key]
|
28 |
+
start = chunk.start
|
29 |
+
end = chunk.end
|
30 |
+
for i in range(chunk.start-1, -1, -1):
|
31 |
+
if any(doc[j].is_ancestor(doc[i]) for j in range(chunk.start, chunk.end)):
|
32 |
+
if not any(any(doc[i].is_ancestor(doc[j]) for j in range(chunks[key2].start, chunks[key2].end)) for key2 in chunks if key != key2):
|
33 |
+
start = i
|
34 |
+
for i in range(chunk.end, len(doc)):
|
35 |
+
if any(doc[j].is_ancestor(doc[i]) for j in range(chunk.start, chunk.end)):
|
36 |
+
if not any(any(doc[i].is_ancestor(doc[j]) or i == j for j in range(chunks[key2].start, chunks[key2].end)) for key2 in chunks if key != key2):
|
37 |
+
end = i+1
|
38 |
+
else:
|
39 |
+
break
|
40 |
+
expanded[key] = Span(doc=doc, start=start, end=end)
|
41 |
+
return expanded
|
42 |
+
|
43 |
+
class Entity(NamedTuple):
|
44 |
+
"""Represents an entity with locative constraints extracted from the parse."""
|
45 |
+
|
46 |
+
head: Span
|
47 |
+
relations: List[Rel]
|
48 |
+
superlatives: List[Sup]
|
49 |
+
|
50 |
+
@classmethod
|
51 |
+
def extract(cls, head, chunks, heuristics: Optional[Heuristics] = None) -> "Entity":
|
52 |
+
"""Extract entities from a spacy parse.
|
53 |
+
|
54 |
+
Jointly recursive with `_get_rel_sups`."""
|
55 |
+
if heuristics is None:
|
56 |
+
heuristics = DEFAULT_HEURISTICS
|
57 |
+
|
58 |
+
if head.i not in chunks:
|
59 |
+
# Handles predicative cases.
|
60 |
+
children = list(head.children)
|
61 |
+
if children and children[0].i in chunks:
|
62 |
+
head = children[0]
|
63 |
+
# TODO: Also extract predicative relations.
|
64 |
+
else:
|
65 |
+
return None
|
66 |
+
hchunk = chunks[head.i]
|
67 |
+
rels, sups = cls._get_rel_sups(head, head, [], chunks, heuristics)
|
68 |
+
return cls(hchunk, rels, sups)
|
69 |
+
|
70 |
+
@classmethod
|
71 |
+
def _get_rel_sups(cls, token, head, tokens, chunks, heuristics) -> Tuple[List[Rel], List[Sup]]:
|
72 |
+
hchunk = chunks[head.i]
|
73 |
+
is_keyword = any(token.text in h.keywords for h in heuristics.relations)
|
74 |
+
is_keyword |= token.text in heuristics.null_keywords
|
75 |
+
|
76 |
+
# Found another entity head.
|
77 |
+
if token.i in chunks and chunks[token.i] is not hchunk and not is_keyword:
|
78 |
+
tchunk = chunks[token.i]
|
79 |
+
tokens.sort(key=lambda tok: tok.i)
|
80 |
+
subhead = cls.extract(token, chunks, heuristics)
|
81 |
+
return [(tokens, subhead)], []
|
82 |
+
|
83 |
+
# End of a chain of modifiers.
|
84 |
+
n_children = len(list(token.children))
|
85 |
+
if n_children == 0:
|
86 |
+
return [], find_superlatives(tokens + [token], heuristics)
|
87 |
+
|
88 |
+
relations = []
|
89 |
+
superlatives = []
|
90 |
+
is_keyword |= any(token.text in h.keywords for h in heuristics.superlatives)
|
91 |
+
for child in token.children:
|
92 |
+
if token.i in chunks and child.i in chunks and chunks[token.i] is chunks[child.i]:
|
93 |
+
if not any(child.text in h.keywords for h in heuristics.superlatives):
|
94 |
+
if n_children == 1:
|
95 |
+
# Catches "the goat on the left"
|
96 |
+
sups = find_superlatives(tokens + [token], heuristics)
|
97 |
+
superlatives.extend(sups)
|
98 |
+
continue
|
99 |
+
new_tokens = tokens + [token] if token.i not in chunks or is_keyword else tokens
|
100 |
+
subrel, subsup = cls._get_rel_sups(child, head, new_tokens, chunks, heuristics)
|
101 |
+
relations.extend(subrel)
|
102 |
+
superlatives.extend(subsup)
|
103 |
+
return relations, superlatives
|
104 |
+
|
105 |
+
def expand(self, span: Span = None):
|
106 |
+
tokens = [token for token in self.head]
|
107 |
+
if span is None:
|
108 |
+
span = [None]
|
109 |
+
for target_token in span:
|
110 |
+
include = False
|
111 |
+
stack = [token for token in self.head]
|
112 |
+
while len(stack) > 0:
|
113 |
+
token = stack.pop()
|
114 |
+
if token == target_token:
|
115 |
+
token2 = target_token.head
|
116 |
+
while token2.head != token2:
|
117 |
+
tokens.append(token2)
|
118 |
+
token2 = token2.head
|
119 |
+
tokens.append(token2)
|
120 |
+
stack = []
|
121 |
+
include = True
|
122 |
+
if target_token is None or include:
|
123 |
+
tokens.append(token)
|
124 |
+
for child in token.children:
|
125 |
+
stack.append(child)
|
126 |
+
tokens = list(set(tokens))
|
127 |
+
tokens = sorted(tokens, key=lambda x: x.i)
|
128 |
+
return ' '.join([token.text for token in tokens])
|
129 |
+
|
130 |
+
def __eq__(self, other: "Entity") -> bool:
|
131 |
+
if self.text != other.text:
|
132 |
+
return False
|
133 |
+
if self.relations != other.relations:
|
134 |
+
return False
|
135 |
+
if self.superlatives != other.superlatives:
|
136 |
+
return False
|
137 |
+
return True
|
138 |
+
|
139 |
+
@property
|
140 |
+
def text(self) -> Text:
|
141 |
+
"""Get the text predicate associated with this entity."""
|
142 |
+
return self.head.text
|
AlphaCLIP/eval/rec_zs_test/executor.py
ADDED
@@ -0,0 +1,401 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Dict, Union, Tuple
|
2 |
+
|
3 |
+
from PIL import Image, ImageDraw, ImageFilter, ImageOps, ImageEnhance
|
4 |
+
import spacy
|
5 |
+
import hashlib
|
6 |
+
import os
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torchvision
|
10 |
+
import torchvision.transforms as transforms
|
11 |
+
import clip
|
12 |
+
from transformers import BertTokenizer, RobertaTokenizerFast
|
13 |
+
import ruamel.yaml as yaml
|
14 |
+
import copy
|
15 |
+
|
16 |
+
from interpreter import Box
|
17 |
+
|
18 |
+
import pycocotools.mask as mask_utils
|
19 |
+
import alpha_clip
|
20 |
+
from segment_anything import sam_model_registry, SamPredictor
|
21 |
+
import numpy as np
|
22 |
+
import cv2
|
23 |
+
import matplotlib.pyplot as plt
|
24 |
+
|
25 |
+
import pickle
|
26 |
+
|
27 |
+
class Executor:
|
28 |
+
def __init__(self, device: str = "cpu", box_representation_method: str = "crop", method_aggregator: str = "max", enlarge_boxes: int = 0, expand_position_embedding: bool = False, square_size: bool = False, blur_std_dev: int = 100, cache_path: str = None, input_file: str = None) -> None:
|
29 |
+
IMPLEMENTED_METHODS = ["blur", "full", "gray"]
|
30 |
+
if any(m not in IMPLEMENTED_METHODS for m in box_representation_method.split(",")):
|
31 |
+
raise NotImplementedError
|
32 |
+
IMPLEMENTED_AGGREGATORS = ["max", "sum"]
|
33 |
+
if method_aggregator not in IMPLEMENTED_AGGREGATORS:
|
34 |
+
raise NotImplementedError
|
35 |
+
self.box_representation_method = box_representation_method
|
36 |
+
self.method_aggregator = method_aggregator
|
37 |
+
self.enlarge_boxes = enlarge_boxes
|
38 |
+
self.device = device
|
39 |
+
self.expand_position_embedding = expand_position_embedding
|
40 |
+
self.square_size = square_size
|
41 |
+
self.blur_std_dev = blur_std_dev
|
42 |
+
self.cache_path = cache_path
|
43 |
+
|
44 |
+
def preprocess_image(self, image: Image) -> List[torch.Tensor]:
|
45 |
+
return [preprocess(image) for preprocess in self.preprocesses]
|
46 |
+
|
47 |
+
def preprocess_mask(self, mask: Image) -> List[torch.Tensor]:
|
48 |
+
preprocess = self.preprocesses[0]
|
49 |
+
return preprocess.transforms[1](preprocess.transforms[0](mask))
|
50 |
+
|
51 |
+
def preprocess_text(self, text: str) -> torch.Tensor:
|
52 |
+
raise NotImplementedError
|
53 |
+
|
54 |
+
def call_model(self, model: torch.nn.Module, images: torch.Tensor, text: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> torch.Tensor:
|
55 |
+
raise NotImplementedError
|
56 |
+
|
57 |
+
def tensorize_inputs(self, caption: str, image: Image, boxes: List[Box], image_name: str = None, image_pth: str = None) -> Tuple[List[torch.Tensor], torch.Tensor]:
|
58 |
+
images = []
|
59 |
+
for preprocess in self.preprocesses:
|
60 |
+
images.append([])
|
61 |
+
|
62 |
+
if 'aclip' in self.clip_type:
|
63 |
+
self.all_masks = []
|
64 |
+
read_save = False
|
65 |
+
if self.mask_path is not None: # load mask if cached
|
66 |
+
file_name = image_pth.split('/')[-1].split('.')[0]+'.pkl'
|
67 |
+
if os.path.exists(os.path.join(self.mask_path, file_name)):
|
68 |
+
all_rles = pickle.load(open(os.path.join(self.mask_path, file_name),'rb'))
|
69 |
+
for rle in all_rles:
|
70 |
+
mask = np.array(mask_utils.decode(rle), dtype=bool)
|
71 |
+
self.all_masks.append(mask)
|
72 |
+
read_save = True
|
73 |
+
if not read_save:
|
74 |
+
# use SAM to generate masks
|
75 |
+
self.predictor.set_image(np.array(image.convert('RGB')))
|
76 |
+
all_rles = []
|
77 |
+
for i in range(len(boxes)):
|
78 |
+
box = [
|
79 |
+
max(boxes[i].left-self.enlarge_boxes, 0),
|
80 |
+
max(boxes[i].top-self.enlarge_boxes, 0),
|
81 |
+
min(boxes[i].right+self.enlarge_boxes, image.width),
|
82 |
+
min(boxes[i].bottom+self.enlarge_boxes, image.height)
|
83 |
+
] # box prompt
|
84 |
+
input_box = np.array(box)
|
85 |
+
masks, _, _ = self.predictor.predict(
|
86 |
+
point_coords=None,
|
87 |
+
point_labels=None,
|
88 |
+
box=input_box[None, :],
|
89 |
+
multimask_output=False,
|
90 |
+
)
|
91 |
+
self.all_masks.append(masks[0])
|
92 |
+
rle = mask_utils.encode(np.array(masks[0][:, :, None], order='F', dtype="uint8"))[0]
|
93 |
+
rle["counts"] = rle["counts"].decode("utf-8")
|
94 |
+
all_rles.append(rle)
|
95 |
+
if self.mask_path is not None: # save mask
|
96 |
+
os.makedirs(self.mask_path, exist_ok=True)
|
97 |
+
pickle.dump(all_rles, open(os.path.join(self.mask_path, file_name),'wb'))
|
98 |
+
|
99 |
+
if self.cache_path is None or any([not os.path.exists(os.path.join(self.cache_path, "refcoco_val", model_name, "image", image_name, method_name+".pt")) for model_name in self.model_names for method_name in self.box_representation_method.split(',')]):
|
100 |
+
if "full" in self.box_representation_method: # original full image with alpha-map
|
101 |
+
for i in range(len(boxes)):
|
102 |
+
image_i = image.copy()
|
103 |
+
preprocessed_images = self.preprocess_image(image_i)
|
104 |
+
for j, img in enumerate(preprocessed_images):
|
105 |
+
images[j].append(img.to(self.device))
|
106 |
+
if "blur" in self.box_representation_method:
|
107 |
+
for i in range(len(boxes)):
|
108 |
+
image_i = image.copy()
|
109 |
+
|
110 |
+
mask = Image.new('L', image_i.size, 0)
|
111 |
+
draw = ImageDraw.Draw(mask)
|
112 |
+
box = (
|
113 |
+
max(boxes[i].left-self.enlarge_boxes, 0),
|
114 |
+
max(boxes[i].top-self.enlarge_boxes, 0),
|
115 |
+
min(boxes[i].right+self.enlarge_boxes, image_i.width),
|
116 |
+
min(boxes[i].bottom+self.enlarge_boxes, image_i.height)
|
117 |
+
)
|
118 |
+
if 'aclip' in self.clip_type:
|
119 |
+
width, height = image.size
|
120 |
+
for y in range(height):
|
121 |
+
for x in range(width):
|
122 |
+
if self.all_masks[i][y][x] == 1:
|
123 |
+
draw.point((x, y), fill=255)
|
124 |
+
else:
|
125 |
+
draw.rectangle([box[:2], box[2:]], fill=255)
|
126 |
+
blurred = image_i.filter(ImageFilter.GaussianBlur(self.blur_std_dev))
|
127 |
+
blurred.paste(image_i, mask=mask)
|
128 |
+
preprocessed_images = self.preprocess_image(blurred)
|
129 |
+
|
130 |
+
for j, img in enumerate(preprocessed_images):
|
131 |
+
images[j].append(img.to(self.device))
|
132 |
+
if "gray" in self.box_representation_method:
|
133 |
+
for i in range(len(boxes)):
|
134 |
+
image_i = image.copy()
|
135 |
+
mask_i = self.all_masks[i]
|
136 |
+
width, height = image.size
|
137 |
+
|
138 |
+
pixels = image_i.load()
|
139 |
+
for y in range(height):
|
140 |
+
for x in range(width):
|
141 |
+
if mask_i[y][x] == 0:
|
142 |
+
pixel_value = pixels[x, y]
|
143 |
+
gray_value = int(0.2989 * pixel_value[0] + 0.5870 * pixel_value[1] + 0.1140 * pixel_value[2])
|
144 |
+
pixels[x, y] = (gray_value, gray_value, gray_value)
|
145 |
+
preprocessed_images = self.preprocess_image(image_i)
|
146 |
+
for j, img in enumerate(preprocessed_images):
|
147 |
+
images[j].append(img.to(self.device))
|
148 |
+
|
149 |
+
imgs = [torch.stack(image_list) for image_list in images]
|
150 |
+
else:
|
151 |
+
imgs = [[] for _ in self.models]
|
152 |
+
text_tensor = self.preprocess_text(caption.lower()).to(self.device)
|
153 |
+
return imgs, text_tensor
|
154 |
+
|
155 |
+
@torch.no_grad()
|
156 |
+
def __call__(self, caption: str, image: Image, boxes: List[Box], image_name: str = None, image_pth=None) -> torch.Tensor:
|
157 |
+
images, text_tensor = self.tensorize_inputs(caption, image, boxes, image_name, image_pth)
|
158 |
+
all_logits_per_image = []
|
159 |
+
all_logits_per_text = []
|
160 |
+
box_representation_methods = self.box_representation_method.split(',')
|
161 |
+
caption_hash = hashlib.md5(caption.encode('utf-8')).hexdigest()
|
162 |
+
for model, images_t, model_name in zip(self.models, images, self.model_names):
|
163 |
+
self.image_feat_path = ""
|
164 |
+
if self.cache_path is not None:
|
165 |
+
text_cache_path = os.path.join(self.cache_path, "refcoco_val", model_name, "text"+("_shade" if self.box_representation_method == "shade" else ""))
|
166 |
+
image_feat_path = os.path.join(self.cache_path, "refcoco_val", model_name, "image", image_name)
|
167 |
+
self.image_feat_path = image_feat_path
|
168 |
+
image_features = None
|
169 |
+
text_features = None
|
170 |
+
if self.cache_path is not None and os.path.exists(os.path.join(self.cache_path, "refcoco_val", model_name)):
|
171 |
+
if os.path.exists(os.path.join(text_cache_path, caption_hash+".pt")):
|
172 |
+
text_features = torch.load(os.path.join(text_cache_path, caption_hash+".pt"), map_location=self.device)
|
173 |
+
if os.path.exists(image_feat_path):
|
174 |
+
if all([os.path.exists(os.path.join(image_feat_path, method_name+".pt")) for method_name in box_representation_methods]):
|
175 |
+
image_features = []
|
176 |
+
for method_name in box_representation_methods:
|
177 |
+
features = torch.load(os.path.join(image_feat_path, method_name+".pt"), map_location=self.device)
|
178 |
+
image_features.append(torch.stack([
|
179 |
+
features[(box.x, box.y, box.w, box.h)]
|
180 |
+
for box in boxes
|
181 |
+
]))
|
182 |
+
image_features = torch.stack(image_features)
|
183 |
+
image_features = image_features.view(-1, image_features.shape[-1])
|
184 |
+
logits_per_image, logits_per_text, image_features, text_features = self.call_model(model, images_t, text_tensor, image_features=image_features, text_features=text_features, boxes=boxes, image_pth=image_pth)
|
185 |
+
all_logits_per_image.append(logits_per_image)
|
186 |
+
all_logits_per_text.append(logits_per_text)
|
187 |
+
if self.cache_path is not None and image_name is not None and image_features is not None:
|
188 |
+
image_features = image_features.view(len(box_representation_methods), len(boxes), image_features.shape[-1])
|
189 |
+
if not os.path.exists(image_feat_path):
|
190 |
+
os.makedirs(image_feat_path)
|
191 |
+
for i in range(image_features.shape[0]):
|
192 |
+
method_name = box_representation_methods[i]
|
193 |
+
if not os.path.exists(os.path.join(image_feat_path, method_name+".pt")):
|
194 |
+
image_features_dict = {(box.x, box.y, box.w, box.h): image_features[i,j,:].cpu() for j, box in enumerate(boxes)}
|
195 |
+
torch.save(image_features_dict, os.path.join(image_feat_path, method_name+".pt"))
|
196 |
+
if self.cache_path is not None and not os.path.exists(os.path.join(text_cache_path, caption_hash+".pt")) and text_features is not None:
|
197 |
+
assert text_features.shape[0] == 1
|
198 |
+
if not os.path.exists(text_cache_path):
|
199 |
+
os.makedirs(text_cache_path)
|
200 |
+
torch.save(text_features.cpu(), os.path.join(text_cache_path, caption_hash+".pt"))
|
201 |
+
|
202 |
+
all_logits_per_image = torch.stack(all_logits_per_image).sum(0)
|
203 |
+
all_logits_per_text = torch.stack(all_logits_per_text).sum(0)
|
204 |
+
if self.method_aggregator == "max":
|
205 |
+
all_logits_per_text = all_logits_per_text.view(-1, len(boxes)).max(dim=0, keepdim=True)[0]
|
206 |
+
elif self.method_aggregator == "sum":
|
207 |
+
all_logits_per_text = all_logits_per_text.view(-1, len(boxes)).sum(dim=0, keepdim=True)
|
208 |
+
return all_logits_per_text.view(-1)
|
209 |
+
|
210 |
+
class ClipExecutor(Executor):
|
211 |
+
def __init__(self, clip_model: str = "ViT-B/32", device: str = "cpu", box_representation_method: str = "crop", method_aggregator: str = "max", enlarge_boxes: int = 0, expand_position_embedding: bool = False, square_size: bool = False, blur_std_dev: int = 100, cache_path: str = None, input_file: str = None, clip_type: str=None) -> None:
|
212 |
+
super().__init__(device, box_representation_method, method_aggregator, enlarge_boxes, expand_position_embedding, square_size, blur_std_dev, cache_path)
|
213 |
+
self.clip_models = clip_model.split(",")
|
214 |
+
self.model_names = [model_name.replace("/", "_") for model_name in self.clip_models]
|
215 |
+
self.models = []
|
216 |
+
self.preprocesses = []
|
217 |
+
self.data_name = input_file.split('/')[-1].split('.')[0]
|
218 |
+
self.mask_path = None
|
219 |
+
self.clip_type = clip_type
|
220 |
+
if self.cache_path is not None:
|
221 |
+
self.mask_path = os.path.join(self.cache_path, "refcoco_val", 'det_masks')
|
222 |
+
sam_checkpoint = "./ckpt/sam_vit_h_4b8939.pth"
|
223 |
+
model_type = "vit_h"
|
224 |
+
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
|
225 |
+
sam.to(device=device)
|
226 |
+
self.predictor = SamPredictor(sam)
|
227 |
+
for model_name in self.clip_models:
|
228 |
+
if 'aclip' in self.clip_type:#using alpha-clip
|
229 |
+
self.mask_transform = transforms.Compose([
|
230 |
+
transforms.ToTensor(),
|
231 |
+
transforms.Resize((224, 224)),
|
232 |
+
transforms.Normalize(0.5, 0.26)
|
233 |
+
])
|
234 |
+
if model_name == 'ViT-B/16':
|
235 |
+
model, preprocess = alpha_clip.load("ViT-B/16", alpha_vision_ckpt_pth="./ckpt/grit1m/clip_b16_grit+mim_fultune_4xe.pth", device=device)
|
236 |
+
elif model_name == 'ViT-L/14':
|
237 |
+
model, preprocess = alpha_clip.load("ViT-L/14", alpha_vision_ckpt_pth="./ckpt/grit1m/clip_l14_grit+mim_fultune_6xe.pth", device=device)
|
238 |
+
|
239 |
+
else: model, preprocess = clip.load(model_name, device=device, jit=False)
|
240 |
+
self.models.append(model)
|
241 |
+
if self.square_size:
|
242 |
+
print("Square size!")
|
243 |
+
preprocess.transforms[0] = transforms.Resize((model.visual.input_resolution, model.visual.input_resolution), interpolation=transforms.InterpolationMode.BICUBIC)
|
244 |
+
self.preprocesses.append(preprocess)
|
245 |
+
self.models = torch.nn.ModuleList(self.models)
|
246 |
+
|
247 |
+
def preprocess_text(self, text: str) -> torch.Tensor:
|
248 |
+
if "aclip" in self.box_representation_method:
|
249 |
+
return alpha_clip.tokenize([text.lower()])
|
250 |
+
if "shade" in self.box_representation_method:
|
251 |
+
return clip.tokenize([text.lower()+" is in red color."])
|
252 |
+
return clip.tokenize(["a photo of "+text.lower()])
|
253 |
+
|
254 |
+
def call_model(self, model: torch.nn.Module, images: torch.Tensor, text: torch.Tensor, image_features: torch.Tensor = None, text_features: torch.Tensor = None, boxes=None, image_pth=None) -> torch.Tensor:
|
255 |
+
if image_features is None:
|
256 |
+
print('computing image features')
|
257 |
+
if 'aclip' not in self.clip_type:
|
258 |
+
image_features = model.encode_image(images)
|
259 |
+
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
|
260 |
+
else:
|
261 |
+
image_features = []
|
262 |
+
if 'full' in self.box_representation_method:
|
263 |
+
aclip_images = images[:len(boxes)]
|
264 |
+
alphas = []
|
265 |
+
|
266 |
+
if os.path.exists(os.path.join(self.image_feat_path, 'full.pt')):
|
267 |
+
features = torch.load(os.path.join(self.image_feat_path, 'full.pt'), map_location=self.device)
|
268 |
+
aclip_image_features = torch.stack([
|
269 |
+
features[(box.x, box.y, box.w, box.h)]
|
270 |
+
for box in boxes
|
271 |
+
])
|
272 |
+
else:
|
273 |
+
for i in range(len(self.all_masks)):
|
274 |
+
binary_mask = self.all_masks[i]
|
275 |
+
alpha = self.mask_transform((binary_mask * 255).astype(np.uint8))
|
276 |
+
alpha = alpha.half().cuda().unsqueeze(dim=0)
|
277 |
+
alphas.append(alpha)
|
278 |
+
|
279 |
+
alphas = torch.cat(alphas, dim=0)
|
280 |
+
aclip_images = aclip_images.half()
|
281 |
+
aclip_image_features = model.visual(aclip_images, alphas) # using alpha channels
|
282 |
+
images = images[len(boxes):]
|
283 |
+
image_features.append(aclip_image_features)
|
284 |
+
|
285 |
+
if 'blur' in self.box_representation_method:
|
286 |
+
if os.path.exists(os.path.join(self.image_feat_path, 'blur.pt')):
|
287 |
+
features = torch.load(os.path.join(self.image_feat_path, 'blur.pt'), map_location=self.device)
|
288 |
+
ablur_images_features = torch.stack([
|
289 |
+
features[(box.x, box.y, box.w, box.h)]
|
290 |
+
for box in boxes
|
291 |
+
])
|
292 |
+
else:
|
293 |
+
ablur_images = images[:len(boxes)]
|
294 |
+
alphas = []
|
295 |
+
for i in range(len(self.all_masks)):
|
296 |
+
binary_mask = self.all_masks[i]
|
297 |
+
alpha = self.mask_transform((binary_mask * 255).astype(np.uint8))
|
298 |
+
alpha = alpha.half().cuda().unsqueeze(dim=0)
|
299 |
+
alphas.append(alpha)
|
300 |
+
alphas = torch.cat(alphas, dim=0)
|
301 |
+
ablur_images = ablur_images.half()
|
302 |
+
ablur_images_features = model.visual(ablur_images, alphas)
|
303 |
+
images = images[len(boxes):]
|
304 |
+
image_features.append(ablur_images_features)
|
305 |
+
|
306 |
+
if 'gray' in self.box_representation_method:
|
307 |
+
if os.path.exists(os.path.join(self.image_feat_path, 'gray.pt')):
|
308 |
+
features = torch.load(os.path.join(self.image_feat_path, 'gray.pt'), map_location=self.device)
|
309 |
+
gray_images_features = torch.stack([
|
310 |
+
features[(box.x, box.y, box.w, box.h)]
|
311 |
+
for box in boxes
|
312 |
+
])
|
313 |
+
else:
|
314 |
+
gray_images = images[:len(boxes)]
|
315 |
+
alphas = []
|
316 |
+
for i in range(len(self.all_masks)):
|
317 |
+
binary_mask = self.all_masks[i]
|
318 |
+
alpha = self.mask_transform((binary_mask * 255).astype(np.uint8))
|
319 |
+
alpha = alpha.half().cuda().unsqueeze(dim=0)
|
320 |
+
alphas.append(alpha)
|
321 |
+
alphas = torch.cat(alphas, dim=0)
|
322 |
+
gray_images = gray_images.half()
|
323 |
+
gray_images_features = model.visual(gray_images, alphas)
|
324 |
+
images = images[len(boxes):]
|
325 |
+
image_features.append(gray_images_features)
|
326 |
+
|
327 |
+
|
328 |
+
image_features = torch.cat(image_features, dim=0)
|
329 |
+
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
|
330 |
+
|
331 |
+
if text_features is None:
|
332 |
+
print('computing text features')
|
333 |
+
text_features = model.encode_text(text)
|
334 |
+
# normalized features
|
335 |
+
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
|
336 |
+
|
337 |
+
# cosine similarity as logits
|
338 |
+
logit_scale = model.logit_scale.exp()
|
339 |
+
logits_per_image = logit_scale * image_features @ text_features.t()
|
340 |
+
logits_per_text = logits_per_image.t()
|
341 |
+
return logits_per_image, logits_per_text, image_features, text_features
|
342 |
+
|
343 |
+
def __call__(self, caption: str, image: Image, boxes: List[Box], image_name: str = None, image_pth=None) -> torch.Tensor:
|
344 |
+
if self.expand_position_embedding:
|
345 |
+
original_preprocesses = self.preprocesses
|
346 |
+
new_preprocesses = []
|
347 |
+
original_position_embeddings = []
|
348 |
+
for model_name, model, preprocess in zip(self.clip_models, self.models, self.preprocesses):
|
349 |
+
if "RN" in model_name:
|
350 |
+
model_spatial_dim = int((model.visual.attnpool.positional_embedding.shape[0]-1)**0.5)
|
351 |
+
patch_size = model.visual.input_resolution // model_spatial_dim
|
352 |
+
original_positional_embedding = model.visual.attnpool.positional_embedding.clone()
|
353 |
+
model.visual.attnpool.positional_embedding = torch.nn.Parameter(torch.nn.functional.interpolate(
|
354 |
+
model.visual.attnpool.positional_embedding[1:,:].permute(1, 0).view(1, -1, model_spatial_dim, model_spatial_dim),
|
355 |
+
size=(image.height // patch_size, image.width // patch_size),
|
356 |
+
mode='bicubic',
|
357 |
+
align_corners=False
|
358 |
+
).squeeze(0).permute(1, 2, 0).view(-1, original_positional_embedding.shape[-1]))
|
359 |
+
model.visual.attnpool.positional_embedding = torch.nn.Parameter(torch.cat((
|
360 |
+
original_positional_embedding[:1,:],
|
361 |
+
model.visual.attnpool.positional_embedding
|
362 |
+
), dim=0))
|
363 |
+
transform = transforms.Compose([
|
364 |
+
transforms.Resize(((image.height // patch_size)*patch_size, (image.width // patch_size)*patch_size), interpolation=Image.BICUBIC),
|
365 |
+
lambda image: image.convert("RGB"),
|
366 |
+
transforms.ToTensor(),
|
367 |
+
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
|
368 |
+
])
|
369 |
+
else:
|
370 |
+
model_spatial_dim = int((model.visual.positional_embedding.shape[0]-1)**0.5)
|
371 |
+
patch_size = model.visual.input_resolution // model_spatial_dim
|
372 |
+
original_positional_embedding = model.visual.positional_embedding.clone()
|
373 |
+
model.visual.positional_embedding = torch.nn.Parameter(torch.nn.functional.interpolate(
|
374 |
+
model.visual.positional_embedding[1:,:].permute(1, 0).view(1, -1, model_spatial_dim, model_spatial_dim),
|
375 |
+
size=(image.height // patch_size, image.width // patch_size),
|
376 |
+
mode='bicubic',
|
377 |
+
align_corners=False
|
378 |
+
).squeeze(0).permute(1, 2, 0).view(-1, original_positional_embedding.shape[-1]))
|
379 |
+
model.visual.positional_embedding = torch.nn.Parameter(torch.cat((
|
380 |
+
original_positional_embedding[:1,:],
|
381 |
+
model.visual.positional_embedding
|
382 |
+
), dim=0))
|
383 |
+
transform = transforms.Compose([
|
384 |
+
transforms.Resize(((image.height // patch_size)*patch_size, (image.width // patch_size)*patch_size), interpolation=Image.BICUBIC),
|
385 |
+
lambda image: image.convert("RGB"),
|
386 |
+
transforms.ToTensor(),
|
387 |
+
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
|
388 |
+
])
|
389 |
+
new_preprocesses.append(transform)
|
390 |
+
original_position_embeddings.append(original_positional_embedding)
|
391 |
+
self.preprocesses = new_preprocesses
|
392 |
+
result = super().__call__(caption, image, boxes, image_name, image_pth)
|
393 |
+
if self.expand_position_embedding:
|
394 |
+
self.preprocesses = original_preprocesses
|
395 |
+
for model, model_name, pos_embedding in zip(self.models, self.clip_models, original_position_embeddings):
|
396 |
+
if "RN" in model_name:
|
397 |
+
model.visual.attnpool.positional_embedding = torch.nn.Parameter(pos_embedding)
|
398 |
+
else:
|
399 |
+
model.visual.positional_embedding = torch.nn.Parameter(pos_embedding)
|
400 |
+
return result
|
401 |
+
|
AlphaCLIP/eval/rec_zs_test/generic_clip_pairs.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import clip
|
3 |
+
import json
|
4 |
+
import argparse
|
5 |
+
import ruamel.yaml as yaml
|
6 |
+
|
7 |
+
from PIL import Image
|
8 |
+
import torch
|
9 |
+
import torchvision.transforms as transforms
|
10 |
+
from tqdm import tqdm
|
11 |
+
|
12 |
+
from albef.utils import *
|
13 |
+
from executor import AlbefExecutor
|
14 |
+
|
15 |
+
parser = argparse.ArgumentParser()
|
16 |
+
parser.add_argument("--input_path", type=str, help="Path to input JSON file")
|
17 |
+
parser.add_argument("--image_root", type=str, help="Path to directory containing images")
|
18 |
+
parser.add_argument("--albef_path", type=str, default=None, help="Path to ALBEF model/config/etc. if the goal is to use ALBEF")
|
19 |
+
parser.add_argument("--albef_itc", action="store_true", help="Use ITC output of ALBEF")
|
20 |
+
parser.add_argument("--clip_model", type=str, help="CLIP model to use")
|
21 |
+
parser.add_argument("--gpu", type=int, default=-1, help="Which gpu to use")
|
22 |
+
parser.add_argument("--batch_size", type=int, default=32, help="Batch size for running CLIP")
|
23 |
+
|
24 |
+
args = parser.parse_args()
|
25 |
+
|
26 |
+
if args.albef_path is not None:
|
27 |
+
executor = AlbefExecutor(checkpoint_path = os.path.join(args.albef_path, "checkpoint.pth"), config_path = os.path.join(args.albef_path, "config.yaml"), device = "cpu" if args.gpu < 0 else "cuda:"+str(args.gpu))
|
28 |
+
model = executor.models[0]
|
29 |
+
preprocess = executor.preprocesses[0]
|
30 |
+
model = model.eval()
|
31 |
+
else:
|
32 |
+
model, preprocess = clip.load(args.clip_model, jit=False, device="cuda:"+str(args.gpu))
|
33 |
+
preprocess.transforms[0] == transforms.Resize((model.visual.input_resolution, model.visual.input_resolution), transforms.InterpolationMode.BICUBIC)
|
34 |
+
model = model.eval()
|
35 |
+
input_file = open(args.input_path)
|
36 |
+
data = json.load(input_file)
|
37 |
+
input_file.close()
|
38 |
+
correct = 0
|
39 |
+
for i in tqdm(range(0, len(data), args.batch_size)):
|
40 |
+
batch_images = []
|
41 |
+
batch_text = []
|
42 |
+
for datum in data[i:min(i+args.batch_size, len(data))]:
|
43 |
+
img = Image.open(os.path.join(args.image_root, datum["image_filename"])).convert('RGB')
|
44 |
+
batch_images.append(preprocess(img))
|
45 |
+
if "text2" in datum:
|
46 |
+
if args.albef_path is None:
|
47 |
+
datum["text1"] = "a photo of "+datum["text1"]
|
48 |
+
datum["text2"] = "a photo of "+datum["text2"]
|
49 |
+
batch_text.append(datum["text1"])
|
50 |
+
batch_text.append(datum["text2"])
|
51 |
+
else:
|
52 |
+
img2 = Image.open(os.path.join(args.image_root, datum["image_filename2"])).convert('RGB')
|
53 |
+
batch_images.append(preprocess(img2))
|
54 |
+
batch_text.append(datum["text1"])
|
55 |
+
batch_images = torch.stack(batch_images).to("cuda:"+str(args.gpu))
|
56 |
+
if args.albef_path is None:
|
57 |
+
batch_text = clip.tokenize(batch_text).to("cuda:"+str(args.gpu))
|
58 |
+
else:
|
59 |
+
modified_text = [pre_caption(txt, executor.max_words) for txt in batch_text]
|
60 |
+
batch_text = executor.tokenizer(modified_text, padding='longest', return_tensors="pt")
|
61 |
+
for key in batch_text:
|
62 |
+
batch_text[key] = batch_text[key].to(batch_images.device)
|
63 |
+
|
64 |
+
with torch.no_grad():
|
65 |
+
if args.albef_path is None:
|
66 |
+
logits_per_image, logits_per_text = model(batch_images, batch_text)
|
67 |
+
else:
|
68 |
+
if not args.albef_itc:
|
69 |
+
if batch_images.shape[0]*2 == batch_text.input_ids.shape[0]:
|
70 |
+
batch_images = batch_images.unsqueeze(1).repeat(1, 2, 1, 1, 1).view(batch_images.shape[0]*2, batch_images.shape[1], batch_images.shape[2], batch_images.shape[3])
|
71 |
+
else:
|
72 |
+
assert batch_images.shape[0] ==2*batch_text.input_ids.shape[0]
|
73 |
+
batch_text.input_ids = batch_text.input_ids.unsqueeze(1).repeat(1, 2, 1).view(batch_images.shape[0], -1)
|
74 |
+
batch_text.attention_mask = batch_text.attention_mask.unsqueeze(1).repeat(1, 2, 1).view(batch_images.shape[0], -1)
|
75 |
+
image_embeds = model.visual_encoder(batch_images)
|
76 |
+
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(batch_images.device)
|
77 |
+
output = model.text_encoder(
|
78 |
+
batch_text.input_ids,
|
79 |
+
attention_mask = batch_text.attention_mask,
|
80 |
+
encoder_hidden_states = image_embeds,
|
81 |
+
encoder_attention_mask = image_atts,
|
82 |
+
return_dict = True,
|
83 |
+
)
|
84 |
+
vl_embeddings = output.last_hidden_state[:,0,:]
|
85 |
+
vl_output = model.itm_head(vl_embeddings)
|
86 |
+
logits_per_image = vl_output[:,1:2].view(-1, 2)
|
87 |
+
else:
|
88 |
+
image_embeds = model.visual_encoder(batch_images)
|
89 |
+
image_feat = torch.nn.functional.normalize(model.vision_proj(image_embeds[:,0,:]),dim=-1)
|
90 |
+
text_output = model.text_encoder(batch_text.input_ids, attention_mask = batch_text.attention_mask,
|
91 |
+
return_dict = True, mode = 'text')
|
92 |
+
text_embeds = text_output.last_hidden_state
|
93 |
+
text_feat = torch.nn.functional.normalize(model.text_proj(text_embeds[:,0,:]),dim=-1)
|
94 |
+
sim = image_feat@text_feat.t()/model.temp
|
95 |
+
logits_per_image = sim
|
96 |
+
if args.albef_path is None or args.albef_itc:
|
97 |
+
if logits_per_image.shape[0]*2 == logits_per_image.shape[1]:
|
98 |
+
for j in range(logits_per_image.shape[0]):
|
99 |
+
correct += 1 if logits_per_image[j,2*j].item() > logits_per_image[j,2*j+1].item() else 0
|
100 |
+
else:
|
101 |
+
assert logits_per_image.shape[0] == 2*logits_per_image.shape[1]
|
102 |
+
for j in range(logits_per_image.shape[1]):
|
103 |
+
correct += 1 if logits_per_image[2*j,j].item() > logits_per_image[2*j+1,j].item() else 0
|
104 |
+
else:
|
105 |
+
correct += (logits_per_image[:,0] > logits_per_image[:,1]).long().sum().item()
|
106 |
+
|
107 |
+
print("Accuracy:", correct/len(data))
|
AlphaCLIP/eval/rec_zs_test/heuristics.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Heuristic rules used to extract and execute entity parses."""
|
2 |
+
|
3 |
+
from typing import Callable, List, NamedTuple
|
4 |
+
from argparse import Namespace
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
|
8 |
+
class RelHeuristic(NamedTuple):
|
9 |
+
keywords: List[str]
|
10 |
+
callback: Callable[["Environment"], np.ndarray]
|
11 |
+
|
12 |
+
|
13 |
+
class Heuristics:
|
14 |
+
"""A class defining heuristics that can be enabled/disabled."""
|
15 |
+
|
16 |
+
RELATIONS = [
|
17 |
+
RelHeuristic(["left", "west"], lambda env: env.left_of()),
|
18 |
+
RelHeuristic(["right", "east"], lambda env: env.right_of()),
|
19 |
+
RelHeuristic(["above", "north", "top", "back", "behind"], lambda env: env.above()),
|
20 |
+
RelHeuristic(["below", "south", "under", "front"], lambda env: env.below()),
|
21 |
+
RelHeuristic(["bigger", "larger", "closer"], lambda env: env.bigger_than()),
|
22 |
+
RelHeuristic(["smaller", "tinier", "further"], lambda env: env.smaller_than()),
|
23 |
+
RelHeuristic(["inside", "within", "contained"], lambda env: env.within()),
|
24 |
+
]
|
25 |
+
|
26 |
+
TERNARY_RELATIONS = [
|
27 |
+
RelHeuristic(["between"], lambda env: env.between()),
|
28 |
+
]
|
29 |
+
|
30 |
+
SUPERLATIVES = [
|
31 |
+
RelHeuristic(["left", "west", "leftmost", "western"], lambda env: env.left_of()),
|
32 |
+
RelHeuristic(["right", "rightmost", "east", "eastern"], lambda env: env.right_of()),
|
33 |
+
RelHeuristic(["above", "north", "top"], lambda env: env.above()),
|
34 |
+
RelHeuristic(["below", "south", "underneath", "front"], lambda env: env.below()),
|
35 |
+
RelHeuristic(["bigger", "biggest", "larger", "largest", "closer", "closest"], lambda env: env.bigger_than()),
|
36 |
+
RelHeuristic(["smaller", "smallest", "tinier", "tiniest", "further", "furthest"], lambda env: env.smaller_than()),
|
37 |
+
]
|
38 |
+
OPPOSITES = {0: 1, 1: 0, 2: 3, 3: 2, 4: 5, 5: 4}
|
39 |
+
|
40 |
+
NULL_KEYWORDS = ["part", "image", "side", "picture", "half", "region", "section"]
|
41 |
+
|
42 |
+
EMPTY = []
|
43 |
+
|
44 |
+
def __init__(self, args: Namespace = None):
|
45 |
+
self.enable_relations = not args or not args.no_rel
|
46 |
+
self.enable_superlatives = not args or not args.no_sup
|
47 |
+
self.enable_nulls = not args or not args.no_null
|
48 |
+
self.enable_ternary = not args or args.ternary
|
49 |
+
|
50 |
+
@property
|
51 |
+
def relations(self) -> List[RelHeuristic]:
|
52 |
+
return self.RELATIONS if self.enable_relations else self.EMPTY
|
53 |
+
|
54 |
+
@property
|
55 |
+
def ternary_relations(self) -> List[RelHeuristic]:
|
56 |
+
return self.TERNARY_RELATIONS if self.enable_ternary else self.EMPTY
|
57 |
+
|
58 |
+
@property
|
59 |
+
def superlatives(self) -> List[RelHeuristic]:
|
60 |
+
return self.SUPERLATIVES if self.enable_superlatives else self.EMPTY
|
61 |
+
|
62 |
+
@property
|
63 |
+
def opposites(self):
|
64 |
+
return self.OPPOSITES
|
65 |
+
|
66 |
+
@property
|
67 |
+
def null_keywords(self) -> List[str]:
|
68 |
+
return self.NULL_KEYWORDS if self.enable_nulls else self.EMPTY
|
AlphaCLIP/eval/rec_zs_test/interpreter.py
ADDED
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import NamedTuple, List, Callable
|
2 |
+
import sys
|
3 |
+
import re
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
from numpy.linalg import norm
|
7 |
+
from itertools import product, groupby
|
8 |
+
from PIL import Image
|
9 |
+
|
10 |
+
|
11 |
+
# Do two line segments intersect? Copied from
|
12 |
+
# https://stackoverflow.com/questions/3838329/how-can-i-check-if-two-segments-intersect
|
13 |
+
|
14 |
+
|
15 |
+
def ccw(A, B, C):
|
16 |
+
return (C.y - A.y) * (B.x - A.x) > (B.y - A.y) * (C.x - A.x)
|
17 |
+
|
18 |
+
|
19 |
+
def intersect(A, B, C, D):
|
20 |
+
"""Do line segments AB and CD intersect?"""
|
21 |
+
return ccw(A, C, D) != ccw(B, C, D) and ccw(A, B, C) != ccw(A, B, D)
|
22 |
+
|
23 |
+
|
24 |
+
class Box(NamedTuple):
|
25 |
+
x: int
|
26 |
+
y: int
|
27 |
+
w: int = 0
|
28 |
+
h: int = 0
|
29 |
+
|
30 |
+
@property
|
31 |
+
def left(self):
|
32 |
+
return self.x
|
33 |
+
|
34 |
+
@property
|
35 |
+
def right(self):
|
36 |
+
return self.x + self.w
|
37 |
+
|
38 |
+
@property
|
39 |
+
def top(self):
|
40 |
+
return self.y
|
41 |
+
|
42 |
+
@property
|
43 |
+
def bottom(self):
|
44 |
+
return self.y + self.h
|
45 |
+
|
46 |
+
@property
|
47 |
+
def center(self):
|
48 |
+
return Box(self.x + self.w // 2, self.y + self.h // 2)
|
49 |
+
|
50 |
+
def corners(self):
|
51 |
+
yield Box(self.x, self.y)
|
52 |
+
yield Box(self.x + self.w, self.y)
|
53 |
+
yield Box(self.x + self.w, self.y + self.h)
|
54 |
+
yield Box(self.x, self.y + self.h)
|
55 |
+
|
56 |
+
@property
|
57 |
+
def area(self):
|
58 |
+
return self.w * self.h
|
59 |
+
|
60 |
+
def intersect(self, other: "Box") -> "Box":
|
61 |
+
x1 = max(self.x, other.x)
|
62 |
+
x2 = max(x1, min(self.x+self.w, other.x+other.w))
|
63 |
+
y1 = max(self.y, other.y)
|
64 |
+
y2 = max(y1, min(self.y+self.h, other.y+other.h))
|
65 |
+
return Box(x=x1, y=y1, w=x2-x1, h=y2-y1)
|
66 |
+
|
67 |
+
def min_bounding(self, other: "Box") -> "Box":
|
68 |
+
corners = list(self.corners())
|
69 |
+
corners.extend(other.corners())
|
70 |
+
min_x = min_y = float("inf")
|
71 |
+
max_x = max_y = -float("inf")
|
72 |
+
|
73 |
+
for item in corners:
|
74 |
+
min_x = min(min_x, item.x)
|
75 |
+
min_y = min(min_y, item.y)
|
76 |
+
max_x = max(max_x, item.x)
|
77 |
+
max_y = max(max_y, item.y)
|
78 |
+
|
79 |
+
return Box(min_x, min_y, max_x - min_x, max_y - min_y)
|
80 |
+
|
81 |
+
def expand(self, growth: float = .1) -> "Box":
|
82 |
+
factor = 1 + growth
|
83 |
+
w = factor * self.w
|
84 |
+
h = factor * self.h
|
85 |
+
return Box(min_x - (w - self.w) / 2, min_y - (h - self.h) / 2, w, h)
|
86 |
+
|
87 |
+
|
88 |
+
def iou(box1, box2):
|
89 |
+
x1 = max(box1.x, box2.x)
|
90 |
+
x2 = max(x1, min(box1.x+box1.w, box2.x+box2.w))
|
91 |
+
y1 = max(box1.y, box2.y)
|
92 |
+
y2 = max(y1, min(box1.y+box1.h, box2.y+box2.h))
|
93 |
+
intersection = Box(x=x1, y=y1, w=x2-x1, h=y2-y1)
|
94 |
+
intersection_area = intersection.area
|
95 |
+
union_area = box1.area+box2.area-intersection_area
|
96 |
+
return intersection_area / union_area
|
97 |
+
|
98 |
+
|
99 |
+
def all_equal(iterable):
|
100 |
+
"""Are all elements the same?"""
|
101 |
+
g = groupby(iterable)
|
102 |
+
return next(g, True) and not next(g, False)
|
103 |
+
|
104 |
+
|
105 |
+
class spatial:
|
106 |
+
"""A decorator that converts a predicate over boxes to a function that returns a tensor over all boxes."""
|
107 |
+
|
108 |
+
def __init__(self, arity: int = 2, enforce_antisymmetry: bool = False):
|
109 |
+
self.arity = arity
|
110 |
+
self.enforce_antisymmetry = enforce_antisymmetry # Zero out any entries where two boxes are the same.
|
111 |
+
|
112 |
+
def __call__(self, predicate: Callable[[Box], float]) -> Callable[["Environment"], np.ndarray]:
|
113 |
+
def _rel(env):
|
114 |
+
n_boxes = len(env.boxes)
|
115 |
+
tensor = np.empty([n_boxes for _ in range(self.arity)])
|
116 |
+
enum_boxes = list(enumerate(env.boxes))
|
117 |
+
for pairs in product(*[enum_boxes for _ in range(self.arity)]):
|
118 |
+
indices, boxes = zip(*pairs)
|
119 |
+
if self.enforce_antisymmetry and len(set(indices)) < len(indices):
|
120 |
+
tensor[indices] = 0.
|
121 |
+
else:
|
122 |
+
tensor[indices] = predicate(*boxes)
|
123 |
+
return tensor
|
124 |
+
return _rel
|
125 |
+
|
126 |
+
|
127 |
+
class Environment:
|
128 |
+
def __init__(self, image: Image, boxes: List[Box], executor: "Executor" = None, freeform_boxes: bool = False, image_name: str = None, image_pth: str=None):
|
129 |
+
self.image = image
|
130 |
+
self.boxes = boxes
|
131 |
+
self.executor = executor # An object or callback that can query CLIP with captions/images.
|
132 |
+
self.freeform_boxes = freeform_boxes
|
133 |
+
self.image_name = image_name
|
134 |
+
self.image_pth=image_pth
|
135 |
+
|
136 |
+
def uniform(self) -> np.ndarray:
|
137 |
+
n_boxes = len(self.boxes)
|
138 |
+
return 1 / n_boxes * np.ones(n_boxes)
|
139 |
+
|
140 |
+
def filter(self,
|
141 |
+
caption: str,
|
142 |
+
temperature: float = 1.,
|
143 |
+
area_threshold: float = 0.0,
|
144 |
+
softmax: bool = False,
|
145 |
+
expand: float = None
|
146 |
+
) -> np.ndarray:
|
147 |
+
"""Return a new distribution reflecting the likelihood that `caption` describes the content of each box."""
|
148 |
+
area_filtered_dist = torch.from_numpy(self.filter_area(area_threshold)).to(self.executor.device)
|
149 |
+
candidate_indices = [i for i in range(len(self.boxes)) if float(area_filtered_dist[i]) > 0.0]
|
150 |
+
boxes = [self.boxes[i] for i in candidate_indices]
|
151 |
+
if len(boxes) == 0:
|
152 |
+
boxes = self.boxes
|
153 |
+
candidate_indices = list(range(len(boxes)))
|
154 |
+
if expand is not None:
|
155 |
+
boxes = [box.expand(expand) for box in boxes]
|
156 |
+
result_partial = self.executor(caption, self.image, boxes, image_name=self.image_name, image_pth=self.image_pth)
|
157 |
+
if self.freeform_boxes:
|
158 |
+
result_partial, boxes = result_partial
|
159 |
+
self.boxes = [Box(x=boxes[i,0].item(), y=boxes[i,1].item(), w=boxes[i,2].item()-boxes[i,0].item(), h=boxes[i,3].item()-boxes[i,1].item()) for i in range(boxes.shape[0])]
|
160 |
+
candidate_indices = list(range(len(self.boxes)))
|
161 |
+
result_partial = result_partial.float()
|
162 |
+
if not softmax:
|
163 |
+
result_partial = (result_partial-result_partial.mean()) / (result_partial.std() + 1e-9)
|
164 |
+
result_partial = (temperature * result_partial).sigmoid()
|
165 |
+
result = torch.zeros((len(self.boxes))).to(result_partial.device)
|
166 |
+
result[candidate_indices] = result_partial
|
167 |
+
else:
|
168 |
+
result = torch.zeros((len(self.boxes))).to(result_partial.device)
|
169 |
+
result[candidate_indices] = result_partial.softmax(dim=-1) #softmax结果
|
170 |
+
return result.cpu().numpy()
|
171 |
+
|
172 |
+
def filter_area(self, area_threshold: float) -> np.ndarray:
|
173 |
+
"""Return a new distribution in which all boxes whose area as a fraction of the image is less than the threshold."""
|
174 |
+
image_area = self.image.width*self.image.height
|
175 |
+
return np.array([1 if self.boxes[i].area/image_area > area_threshold else 0 for i in range(len(self.boxes))])
|
176 |
+
|
177 |
+
@spatial()
|
178 |
+
def left_of(b1, b2):
|
179 |
+
return (b1.right+b1.left) / 2 < (b2.right+b2.left) / 2
|
180 |
+
|
181 |
+
@spatial()
|
182 |
+
def right_of(b1, b2):
|
183 |
+
return (b1.right+b1.left) / 2 > (b2.right+b2.left) / 2
|
184 |
+
|
185 |
+
@spatial()
|
186 |
+
def above(b1, b2):
|
187 |
+
return (b1.bottom+b1.top) < (b2.bottom+b2.top)
|
188 |
+
|
189 |
+
@spatial()
|
190 |
+
def below(b1, b2):
|
191 |
+
return (b1.bottom+b1.top) > (b2.bottom+b2.top)
|
192 |
+
|
193 |
+
@spatial()
|
194 |
+
def bigger_than(b1, b2):
|
195 |
+
return b1.area > b2.area
|
196 |
+
|
197 |
+
@spatial()
|
198 |
+
def smaller_than(b1, b2):
|
199 |
+
return b1.area < b2.area
|
200 |
+
|
201 |
+
@spatial(enforce_antisymmetry=False)
|
202 |
+
def within(box1, box2):
|
203 |
+
"""Return percent of box1 inside box2."""
|
204 |
+
intersection = box1.intersect(box2)
|
205 |
+
return intersection.area / box1.area
|
206 |
+
|
207 |
+
@spatial(arity=3, enforce_antisymmetry=True)
|
208 |
+
def between(box1, box2, box3):
|
209 |
+
"""How much of box1 lies in min bounding box over box2 and box3?"""
|
210 |
+
min_bounding = box2.min_bounding(box3)
|
211 |
+
intersect = box1.intersect(min_bounding)
|
212 |
+
return intersect.area / box1.area
|
AlphaCLIP/eval/rec_zs_test/lattice.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Implement lattice interface."""
|
2 |
+
|
3 |
+
from overrides import overrides
|
4 |
+
import numpy as np
|
5 |
+
from abc import ABCMeta, abstractmethod
|
6 |
+
|
7 |
+
|
8 |
+
class Lattice(metaclass=ABCMeta):
|
9 |
+
|
10 |
+
"""Abstract base class representing a complemented lattice."""
|
11 |
+
|
12 |
+
@classmethod
|
13 |
+
@abstractmethod
|
14 |
+
def join(cls, probs1: np.ndarray, probs2: np.ndarray) -> np.ndarray:
|
15 |
+
return NotImplemented
|
16 |
+
|
17 |
+
@classmethod
|
18 |
+
@abstractmethod
|
19 |
+
def meet(cls, probs1: np.ndarray, probs2: np.ndarray) -> np.ndarray:
|
20 |
+
return NotImplemented
|
21 |
+
|
22 |
+
@classmethod
|
23 |
+
@abstractmethod
|
24 |
+
def join_reduce(cls, probs: np.ndarray) -> np.ndarray:
|
25 |
+
return NotImplemented
|
26 |
+
|
27 |
+
@classmethod
|
28 |
+
@abstractmethod
|
29 |
+
def meet_reduce(cls, probs: np.ndarray) -> np.ndarray:
|
30 |
+
return NotImplemented
|
31 |
+
|
32 |
+
|
33 |
+
class Product(Lattice):
|
34 |
+
"""Lattice where meet=prod and sum is defined accordingly.
|
35 |
+
|
36 |
+
Equivalent to assuming independence, more or less.
|
37 |
+
"""
|
38 |
+
|
39 |
+
eps = 1e-9
|
40 |
+
|
41 |
+
@classmethod
|
42 |
+
@overrides
|
43 |
+
def join(cls, probs1: np.ndarray, probs2: np.ndarray) -> np.ndarray:
|
44 |
+
return probs1 + probs2 - cls.meet(probs1, probs2)
|
45 |
+
|
46 |
+
@classmethod
|
47 |
+
@overrides
|
48 |
+
def meet(cls, probs1: np.ndarray, probs2: np.ndarray) -> np.ndarray:
|
49 |
+
return probs1 * probs2
|
50 |
+
|
51 |
+
@classmethod
|
52 |
+
@overrides
|
53 |
+
def join_reduce(cls, probs: np.ndarray) -> np.ndarray:
|
54 |
+
"""Assumes disjoint events."""
|
55 |
+
# return cls.comp(cls.meet_reduce(cls.comp(probs)))
|
56 |
+
return np.sum(probs, axis=-1)
|
57 |
+
|
58 |
+
@classmethod
|
59 |
+
@overrides
|
60 |
+
def meet_reduce(cls, probs: np.ndarray) -> np.ndarray:
|
61 |
+
return np.prod(probs, axis=-1)
|
62 |
+
|
63 |
+
@classmethod
|
64 |
+
def comp(cls, probs):
|
65 |
+
return 1 - probs
|
66 |
+
|
67 |
+
@classmethod
|
68 |
+
def normalize(cls, probs):
|
69 |
+
"""Normalize a distribution by dividing by the total mass."""
|
70 |
+
return probs / np.sum(probs + cls.eps, axis=-1)
|
AlphaCLIP/eval/rec_zs_test/main.py
ADDED
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import defaultdict
|
2 |
+
import json
|
3 |
+
import argparse
|
4 |
+
import os
|
5 |
+
import random
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from PIL import Image
|
9 |
+
from tqdm import tqdm
|
10 |
+
|
11 |
+
from interpreter import *
|
12 |
+
from executor import *
|
13 |
+
from methods import *
|
14 |
+
|
15 |
+
METHODS_MAP = {
|
16 |
+
"baseline": Baseline,
|
17 |
+
"random": Random,
|
18 |
+
"parse": Parse,
|
19 |
+
}
|
20 |
+
|
21 |
+
if __name__ == "__main__":
|
22 |
+
parser = argparse.ArgumentParser()
|
23 |
+
parser.add_argument("--input_file", type=str, help="input file with expressions and annotations in jsonlines format")
|
24 |
+
parser.add_argument("--image_root", type=str, help="path to images (train2014 directory of COCO)")
|
25 |
+
parser.add_argument("--clip_model", type=str, default="RN50x16,ViT-B/32", help="which clip model to use (should use RN50x4, ViT-B/32, or both separated by a comma")
|
26 |
+
parser.add_argument("--clip_type", type=str, default="aclip", help="which clip model to use (should use RN50x4, ViT-B/32, or both separated by a comma")
|
27 |
+
parser.add_argument("--albef_path", type=str, default=None, help="to use ALBEF (instead of CLIP), specify the path to the ALBEF checkpoint")
|
28 |
+
parser.add_argument("--method", type=str, default="parse", help="method to solve expressions")
|
29 |
+
parser.add_argument("--box_representation_method", type=str, default="crop,blur", help="method of representing boxes as individual images (crop, blur, or both separated by a comma)")
|
30 |
+
parser.add_argument("--box_method_aggregator", type=str, default="sum", help="method of combining box representation scores")
|
31 |
+
parser.add_argument("--box_area_threshold", type=float, default=0.0, help="minimum area (as a proportion of image area) for a box to be considered as the answer")
|
32 |
+
parser.add_argument("--output_file", type=str, default=None, help="(optional) output path to save results")
|
33 |
+
parser.add_argument("--detector_file", type=str, default=None, help="(optional) file containing object detections. if not provided, the gold object boxes will be used.")
|
34 |
+
parser.add_argument("--mock", action="store_true", help="(optional) mock CLIP execution.")
|
35 |
+
parser.add_argument("--device", type=int, default=0, help="CUDA device to use.")
|
36 |
+
parser.add_argument("--shuffle_words", action="store_true", help="If true, shuffle words in the sentence")
|
37 |
+
parser.add_argument("--gradcam_alpha", type=float, nargs='+', help="alpha value to use for gradcam method")
|
38 |
+
parser.add_argument("--enlarge_boxes", type=float, default=0.0, help="(optional) whether to enlarge boxes when passing them to the model")
|
39 |
+
parser.add_argument("--part", type=str, default=None, help="(optional) specify how many parts to divide the dataset into and which part to run in the format NUM_PARTS,PART_NUM")
|
40 |
+
parser.add_argument("--batch_size", type=int, default=1, help="number of instances to process in one model call (only supported for baseline model)")
|
41 |
+
parser.add_argument("--baseline_head", action="store_true", help="For baseline, controls whether model is called on both full expression and head noun chunk of expression")
|
42 |
+
parser.add_argument("--mdetr", type=str, default=None, help="to use MDETR as the executor model, specify the name of the MDETR model")
|
43 |
+
parser.add_argument("--albef_block_num", type=int, default=8, help="block num for ALBEF gradcam")
|
44 |
+
parser.add_argument("--albef_mode", type=str, choices=["itm", "itc"], default="itm")
|
45 |
+
parser.add_argument("--expand_position_embedding",action="store_true")
|
46 |
+
parser.add_argument("--gradcam_background", action="store_true")
|
47 |
+
parser.add_argument("--mdetr_given_bboxes", action="store_true")
|
48 |
+
parser.add_argument("--mdetr_use_token_mapping", action="store_true")
|
49 |
+
parser.add_argument("--non_square_size", action="store_true")
|
50 |
+
parser.add_argument("--blur_std_dev", type=int, default=100, help="standard deviation of Gaussian blur")
|
51 |
+
parser.add_argument("--gradcam_ensemble_before", action="store_true", help="Average gradcam maps of different models before summing over the maps")
|
52 |
+
parser.add_argument("--cache_path", type=str, default=None, help="cache features")
|
53 |
+
# Arguments related to Parse method.
|
54 |
+
parser.add_argument("--no_rel", action="store_true", help="Disable relation extraction.")
|
55 |
+
parser.add_argument("--no_sup", action="store_true", help="Disable superlative extraction.")
|
56 |
+
parser.add_argument("--no_null", action="store_true", help="Disable null keyword heuristics.")
|
57 |
+
parser.add_argument("--ternary", action="store_true", help="Disable ternary relation extraction.")
|
58 |
+
parser.add_argument("--baseline_threshold", type=float, default=float("inf"), help="(Parse) Threshold to use relations/superlatives.")
|
59 |
+
parser.add_argument("--temperature", type=float, default=1., help="(Parse) Sigmoid temperature.")
|
60 |
+
parser.add_argument("--superlative_head_only", action="store_true", help="(Parse) Superlatives only quanntify head predicate.")
|
61 |
+
parser.add_argument("--sigmoid", action="store_true", help="(Parse) Use sigmoid, not softmax.")
|
62 |
+
parser.add_argument("--no_possessive", action="store_true", help="(Parse) Model extraneous relations as possessive relations.")
|
63 |
+
parser.add_argument("--expand_chunks", action="store_true", help="(Parse) Expand noun chunks to include descendant tokens that aren't ancestors of tokens in other chunks")
|
64 |
+
parser.add_argument("--parse_no_branch", action="store_true", help="(Parse) Only do the parsing procedure if some relation/superlative keyword is in the expression")
|
65 |
+
parser.add_argument("--possessive_no_expand", action="store_true", help="(Parse) Expand ent2 in possessive case")
|
66 |
+
args = parser.parse_args()
|
67 |
+
|
68 |
+
with open(args.input_file) as f:
|
69 |
+
lines = f.readlines()
|
70 |
+
data = [json.loads(line) for line in lines]
|
71 |
+
|
72 |
+
device = f"cuda:{args.device}" if torch.cuda.is_available() and args.device >= 0 else "cpu"
|
73 |
+
gradcam = args.method == "gradcam"
|
74 |
+
|
75 |
+
executor = ClipExecutor(clip_model=args.clip_model, box_representation_method=args.box_representation_method, method_aggregator=args.box_method_aggregator, device=device, square_size=not args.non_square_size, expand_position_embedding=args.expand_position_embedding, blur_std_dev=args.blur_std_dev, cache_path=args.cache_path, input_file=args.input_file, clip_type=args.clip_type)
|
76 |
+
|
77 |
+
method = METHODS_MAP[args.method](args)
|
78 |
+
correct_count = 0
|
79 |
+
total_count = 0
|
80 |
+
if args.output_file:
|
81 |
+
output_file = open(args.output_file, "w")
|
82 |
+
if args.detector_file:
|
83 |
+
detector_file = open(args.detector_file)
|
84 |
+
detections_list = json.load(detector_file)
|
85 |
+
if isinstance(detections_list, dict):
|
86 |
+
detections_map = {int(image_id): detections_list[image_id] for image_id in detections_list}
|
87 |
+
else:
|
88 |
+
detections_map = defaultdict(list)
|
89 |
+
for detection in detections_list:
|
90 |
+
detections_map[detection["image_id"]].append(detection["box"])
|
91 |
+
|
92 |
+
part = 0
|
93 |
+
if args.part is not None: # for multi-gpu test / part-data test
|
94 |
+
num_parts = int(args.part.split(",")[0])
|
95 |
+
part = int(args.part.split(",")[1])
|
96 |
+
data = data[int(len(data)*part/num_parts):int(len(data)*(part+1)/num_parts)]
|
97 |
+
|
98 |
+
batch_count = 0
|
99 |
+
batch_boxes = []
|
100 |
+
batch_gold_boxes = []
|
101 |
+
batch_gold_index = []
|
102 |
+
batch_file_names = []
|
103 |
+
batch_sentences = []
|
104 |
+
for datum in tqdm(data):
|
105 |
+
if "coco" in datum["file_name"].lower():
|
106 |
+
file_name = "_".join(datum["file_name"].split("_")[:-1])+".jpg"
|
107 |
+
else:
|
108 |
+
file_name = datum["file_name"]
|
109 |
+
img_path = os.path.join(args.image_root, file_name)
|
110 |
+
img = Image.open(img_path).convert('RGB')
|
111 |
+
gold_boxes = [Box(x=ann["bbox"][0], y=ann["bbox"][1], w=ann["bbox"][2], h=ann["bbox"][3]) for ann in datum["anns"]]
|
112 |
+
if isinstance(datum["ann_id"], int) or isinstance(datum["ann_id"], str):
|
113 |
+
datum["ann_id"] = [datum["ann_id"]]
|
114 |
+
assert isinstance(datum["ann_id"], list)
|
115 |
+
gold_index = [i for i in range(len(datum["anns"])) if datum["anns"][i]["id"] in datum["ann_id"]]
|
116 |
+
if args.detector_file:
|
117 |
+
boxes = [Box(x=box[0], y=box[1], w=box[2], h=box[3]) for box in detections_map[int(datum["image_id"])]]
|
118 |
+
if len(boxes) == 0:
|
119 |
+
boxes = [Box(x=0, y=0, w=img.width, h=img.height)]
|
120 |
+
else:
|
121 |
+
boxes = gold_boxes
|
122 |
+
for sentence in datum["sentences"]:
|
123 |
+
env = Environment(img, boxes, executor, (args.mdetr is not None and not args.mdetr_given_bboxes), str(datum["image_id"]), img_path)
|
124 |
+
if args.shuffle_words:
|
125 |
+
words = sentence["raw"].lower().split()
|
126 |
+
random.shuffle(words)
|
127 |
+
result = method.execute(" ".join(words), env)
|
128 |
+
else:
|
129 |
+
result = method.execute(sentence["raw"].lower(), env)
|
130 |
+
boxes = env.boxes
|
131 |
+
print(sentence["raw"].lower())
|
132 |
+
correct = False
|
133 |
+
for g_index in gold_index:
|
134 |
+
if iou(boxes[result["pred"]], gold_boxes[g_index]) > 0.5:
|
135 |
+
correct = True
|
136 |
+
break
|
137 |
+
if correct:
|
138 |
+
result["correct"] = 1
|
139 |
+
correct_count += 1
|
140 |
+
else:
|
141 |
+
result["correct"] = 0
|
142 |
+
if args.detector_file:
|
143 |
+
argmax_ious = []
|
144 |
+
max_ious = []
|
145 |
+
for g_index in gold_index:
|
146 |
+
ious = [iou(box, gold_boxes[g_index]) for box in boxes]
|
147 |
+
argmax_iou = -1
|
148 |
+
max_iou = 0
|
149 |
+
if max(ious) >= 0.5:
|
150 |
+
for index, value in enumerate(ious):
|
151 |
+
if value > max_iou:
|
152 |
+
max_iou = value
|
153 |
+
argmax_iou = index
|
154 |
+
argmax_ious.append(argmax_iou)
|
155 |
+
max_ious.append(max_iou)
|
156 |
+
argmax_iou = -1
|
157 |
+
max_iou = 0
|
158 |
+
if max(max_ious) >= 0.5:
|
159 |
+
for index, value in zip(argmax_ious, max_ious):
|
160 |
+
if value > max_iou:
|
161 |
+
max_iou = value
|
162 |
+
argmax_iou = index
|
163 |
+
result["gold_index"] = argmax_iou
|
164 |
+
else:
|
165 |
+
result["gold_index"] = gold_index
|
166 |
+
result["bboxes"] = [[box.left, box.top, box.right, box.bottom] for box in boxes]
|
167 |
+
result["file_name"] = file_name
|
168 |
+
result["probabilities"] = result["probs"]
|
169 |
+
result["text"] = sentence["raw"].lower()
|
170 |
+
if args.output_file:
|
171 |
+
# Serialize numpy arrays for JSON.
|
172 |
+
for key in result:
|
173 |
+
if isinstance(result[key], np.ndarray):
|
174 |
+
result[key] = result[key].tolist()
|
175 |
+
if isinstance(result[key], np.int64):
|
176 |
+
result[key] = result[key].item()
|
177 |
+
output_file.write(json.dumps(result)+"\n")
|
178 |
+
total_count += 1
|
179 |
+
print(f"est_acc: {100 * correct_count / total_count:.3f}")
|
180 |
+
|
181 |
+
if args.output_file:
|
182 |
+
output_file.close()
|
183 |
+
print(f"acc: {100 * correct_count / total_count:.3f}")
|
184 |
+
acc = 100 * correct_count / total_count
|
185 |
+
|
186 |
+
result = {}
|
187 |
+
result['acc'] = acc
|
188 |
+
json.dump(acc, open(os.path.join('./output', args.input_file.split('/')[-1].split('.')[0] + '_acc_' + str(part)+'.json'),'w'))
|
189 |
+
json.dump(str(correct_count)+' '+str(total_count), open(os.path.join('./output', args.input_file.split('/')[-1].split('.')[0] + '_count_' + str(part)+'.json'),'w'))
|
190 |
+
stats = method.get_stats()
|
191 |
+
if stats:
|
192 |
+
pairs = sorted(list(stats.items()), key=lambda tup: tup[0])
|
193 |
+
for key, value in pairs:
|
194 |
+
result[key] = value
|
195 |
+
if isinstance(value, float):
|
196 |
+
print(f"{key}: {value:.5f}")
|
197 |
+
else:
|
198 |
+
print(f"{key}: {value}")
|
199 |
+
|
200 |
+
json.dump(result, open(os.path.join('./output', args.input_file.split('/')[-1].split('.')[0] + '_' + str(part)+'.json'),'w'))
|
AlphaCLIP/eval/rec_zs_test/methods/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .baseline import Baseline
|
2 |
+
from .random_method import Random
|
3 |
+
from .parse import Parse
|
AlphaCLIP/eval/rec_zs_test/methods/baseline.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""A naive baseline method: just pass the full expression to CLIP."""
|
2 |
+
|
3 |
+
from overrides import overrides
|
4 |
+
from typing import Dict, Any, List
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import spacy
|
8 |
+
from argparse import Namespace
|
9 |
+
|
10 |
+
from .ref_method import RefMethod
|
11 |
+
from lattice import Product as L
|
12 |
+
|
13 |
+
|
14 |
+
class Baseline(RefMethod):
|
15 |
+
"""CLIP-only baseline where each box is evaluated with the full expression."""
|
16 |
+
|
17 |
+
nlp = spacy.load('en_core_web_sm')
|
18 |
+
|
19 |
+
def __init__(self, args: Namespace):
|
20 |
+
self.args = args
|
21 |
+
self.box_area_threshold = args.box_area_threshold
|
22 |
+
self.batch_size = args.batch_size
|
23 |
+
self.batch = []
|
24 |
+
|
25 |
+
@overrides
|
26 |
+
def execute(self, caption: str, env: "Environment") -> Dict[str, Any]:
|
27 |
+
chunk_texts = self.get_chunk_texts(caption)
|
28 |
+
probs = env.filter(caption, area_threshold = self.box_area_threshold, softmax=True)
|
29 |
+
if self.args.baseline_head:
|
30 |
+
probs2 = env.filter(chunk_texts[0], area_threshold = self.box_area_threshold, softmax=True)
|
31 |
+
probs = L.meet(probs, probs2)
|
32 |
+
pred = np.argmax(probs)
|
33 |
+
return {
|
34 |
+
"probs": probs,
|
35 |
+
"pred": pred,
|
36 |
+
"box": env.boxes[pred],
|
37 |
+
}
|
38 |
+
|
39 |
+
def get_chunk_texts(self, expression: str) -> List:
|
40 |
+
doc = self.nlp(expression)
|
41 |
+
head = None
|
42 |
+
for token in doc:
|
43 |
+
if token.head.i == token.i:
|
44 |
+
head = token
|
45 |
+
break
|
46 |
+
head_chunk = None
|
47 |
+
chunk_texts = []
|
48 |
+
for chunk in doc.noun_chunks:
|
49 |
+
if head.i >= chunk.start and head.i < chunk.end:
|
50 |
+
head_chunk = chunk.text
|
51 |
+
chunk_texts.append(chunk.text)
|
52 |
+
if head_chunk is None:
|
53 |
+
if len(list(doc.noun_chunks)) > 0:
|
54 |
+
head_chunk = list(doc.noun_chunks)[0].text
|
55 |
+
else:
|
56 |
+
head_chunk = expression
|
57 |
+
return [head_chunk] + [txt for txt in chunk_texts if txt != head_chunk]
|
AlphaCLIP/eval/rec_zs_test/methods/parse.py
ADDED
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Use spatial relations extracted from the parses."""
|
2 |
+
|
3 |
+
from typing import Dict, Any, Callable, List, Tuple, NamedTuple
|
4 |
+
from numbers import Number
|
5 |
+
from collections import defaultdict
|
6 |
+
from overrides import overrides
|
7 |
+
import numpy as np
|
8 |
+
import spacy
|
9 |
+
from spacy.tokens.token import Token
|
10 |
+
from spacy.tokens.span import Span
|
11 |
+
from argparse import Namespace
|
12 |
+
|
13 |
+
from .ref_method import RefMethod
|
14 |
+
from lattice import Product as L
|
15 |
+
from heuristics import Heuristics
|
16 |
+
from entity_extraction import Entity, expand_chunks
|
17 |
+
|
18 |
+
|
19 |
+
def get_conjunct(ent, chunks, heuristics: Heuristics) -> Entity:
|
20 |
+
"""If an entity represents a conjunction of two entities, pull them apart."""
|
21 |
+
head = ent.head.root # Not ...root.head. Confusing names here.
|
22 |
+
if not any(child.text == "and" for child in head.children):
|
23 |
+
return None
|
24 |
+
for child in head.children:
|
25 |
+
if child.i in chunks and head.i is not child.i:
|
26 |
+
return Entity.extract(child, chunks, heuristics)
|
27 |
+
return None
|
28 |
+
|
29 |
+
|
30 |
+
class Parse(RefMethod):
|
31 |
+
"""An REF method that extracts and composes predicates, relations, and superlatives from a dependency parse.
|
32 |
+
|
33 |
+
The process is as follows:
|
34 |
+
1. Use spacy to parse the document.
|
35 |
+
2. Extract a semantic entity tree from the parse.
|
36 |
+
3. Execute the entity tree to yield a distribution over boxes."""
|
37 |
+
|
38 |
+
nlp = spacy.load('en_core_web_sm')
|
39 |
+
|
40 |
+
def __init__(self, args: Namespace = None):
|
41 |
+
self.args = args
|
42 |
+
self.box_area_threshold = args.box_area_threshold
|
43 |
+
self.baseline_threshold = args.baseline_threshold
|
44 |
+
self.temperature = args.temperature
|
45 |
+
self.superlative_head_only = args.superlative_head_only
|
46 |
+
self.expand_chunks = args.expand_chunks
|
47 |
+
self.branch = not args.parse_no_branch
|
48 |
+
self.possessive_expand = not args.possessive_no_expand
|
49 |
+
|
50 |
+
# Lists of keyword heuristics to use.
|
51 |
+
self.heuristics = Heuristics(args)
|
52 |
+
|
53 |
+
# Metrics for debugging relation extraction behavor.
|
54 |
+
self.counts = defaultdict(int)
|
55 |
+
|
56 |
+
@overrides
|
57 |
+
def execute(self, caption: str, env: "Environment") -> Dict[str, Any]:
|
58 |
+
"""Construct an `Entity` tree from the parse and execute it to yield a distribution over boxes."""
|
59 |
+
# Start by using the full caption, as in Baseline.
|
60 |
+
probs = env.filter(caption, area_threshold=self.box_area_threshold, softmax=True)
|
61 |
+
ori_probs = probs
|
62 |
+
|
63 |
+
# Extend the baseline using parse stuff.
|
64 |
+
doc = self.nlp(caption)
|
65 |
+
head = self.get_head(doc)
|
66 |
+
chunks = self.get_chunks(doc)
|
67 |
+
if self.expand_chunks:
|
68 |
+
chunks = expand_chunks(doc, chunks)
|
69 |
+
entity = Entity.extract(head, chunks, self.heuristics)
|
70 |
+
|
71 |
+
# If no head noun is found, take the first one.
|
72 |
+
if entity is None and len(list(doc.noun_chunks)) > 0:
|
73 |
+
head = list(doc.noun_chunks)[0]
|
74 |
+
entity = Entity.extract(head.root.head, chunks, self.heuristics)
|
75 |
+
self.counts["n_0th_noun"] += 1
|
76 |
+
|
77 |
+
# If we have found some head noun, filter based on it.
|
78 |
+
if entity is not None and (any(any(token.text in h.keywords for h in self.heuristics.relations+self.heuristics.superlatives) for token in doc) or not self.branch):
|
79 |
+
ent_probs, texts = self.execute_entity(entity, env, chunks)
|
80 |
+
probs = L.meet(probs, ent_probs)
|
81 |
+
else:
|
82 |
+
texts = [caption]
|
83 |
+
self.counts["n_full_expr"] += 1
|
84 |
+
|
85 |
+
if len(ori_probs) == 1:
|
86 |
+
probs = ori_probs
|
87 |
+
|
88 |
+
self.counts["n_total"] += 1
|
89 |
+
pred = np.argmax(probs)
|
90 |
+
return {
|
91 |
+
"probs": probs,
|
92 |
+
"pred": pred,
|
93 |
+
"box": env.boxes[pred],
|
94 |
+
"texts": texts
|
95 |
+
}
|
96 |
+
|
97 |
+
def execute_entity(self,
|
98 |
+
ent: Entity,
|
99 |
+
env: "Environment",
|
100 |
+
chunks: Dict[int, Span],
|
101 |
+
root: bool = True,
|
102 |
+
) -> np.ndarray:
|
103 |
+
"""Execute an `Entity` tree recursively, yielding a distribution over boxes."""
|
104 |
+
self.counts["n_rec"] += 1
|
105 |
+
probs = [1, 1]
|
106 |
+
head_probs = probs
|
107 |
+
|
108 |
+
# Only use relations if the head baseline isn't certain.
|
109 |
+
if len(probs) == 1 or len(env.boxes) == 1:
|
110 |
+
return probs, [ent.text]
|
111 |
+
|
112 |
+
m1, m2 = probs[:2] # probs[(-probs).argsort()[:2]]
|
113 |
+
text = ent.text
|
114 |
+
rel_probs = []
|
115 |
+
if self.baseline_threshold == float("inf") or m1 < self.baseline_threshold * m2:
|
116 |
+
self.counts["n_rec_rel"] += 1
|
117 |
+
for tokens, ent2 in ent.relations:
|
118 |
+
self.counts["n_rel"] += 1
|
119 |
+
rel = None
|
120 |
+
# Heuristically decide which spatial relation is represented.
|
121 |
+
for heuristic in self.heuristics.relations:
|
122 |
+
if any(tok.text in heuristic.keywords for tok in tokens):
|
123 |
+
rel = heuristic.callback(env)
|
124 |
+
self.counts[f"n_rel_{heuristic.keywords[0]}"] += 1
|
125 |
+
break
|
126 |
+
# Filter and normalize by the spatial relation.
|
127 |
+
if rel is not None:
|
128 |
+
probs2 = self.execute_entity(ent2, env, chunks, root=False)
|
129 |
+
events = L.meet(np.expand_dims(probs2, axis=0), rel)
|
130 |
+
new_probs = L.join_reduce(events)
|
131 |
+
rel_probs.append((ent2.text, new_probs, probs2))
|
132 |
+
continue
|
133 |
+
|
134 |
+
# This case specifically handles "between", which takes two noun arguments.
|
135 |
+
rel = None
|
136 |
+
for heuristic in self.heuristics.ternary_relations:
|
137 |
+
if any(tok.text in heuristic.keywords for tok in tokens):
|
138 |
+
rel = heuristic.callback(env)
|
139 |
+
self.counts[f"n_rel_{heuristic.keywords[0]}"] += 1
|
140 |
+
break
|
141 |
+
if rel is not None:
|
142 |
+
ent3 = get_conjunct(ent2, chunks, self.heuristics)
|
143 |
+
if ent3 is not None:
|
144 |
+
probs2 = self.execute_entity(ent2, env, chunks, root=False)
|
145 |
+
probs2 = np.expand_dims(probs2, axis=[0, 2])
|
146 |
+
probs3 = self.execute_entity(ent3, env, chunks, root=False)
|
147 |
+
probs3 = np.expand_dims(probs3, axis=[0, 1])
|
148 |
+
events = L.meet(L.meet(probs2, probs3), rel)
|
149 |
+
new_probs = L.join_reduce(L.join_reduce(events))
|
150 |
+
probs = L.meet(probs, new_probs)
|
151 |
+
continue
|
152 |
+
# Otherwise, treat the relation as a possessive relation.
|
153 |
+
if not self.args.no_possessive:
|
154 |
+
if self.possessive_expand:
|
155 |
+
text = ent.expand(ent2.head)
|
156 |
+
else:
|
157 |
+
text += f' {" ".join(tok.text for tok in tokens)} {ent2.text}'
|
158 |
+
#poss_probs = self._filter(text, env, root=root, expand=.3)
|
159 |
+
probs = self._filter(text, env, root=root)
|
160 |
+
texts = [text]
|
161 |
+
return_probs = [(probs.tolist(), probs.tolist())]
|
162 |
+
for (ent2_text, new_probs, ent2_only_probs) in rel_probs:
|
163 |
+
probs = L.meet(probs, new_probs)
|
164 |
+
probs /= probs.sum()
|
165 |
+
texts.append(ent2_text)
|
166 |
+
return_probs.append((probs.tolist(), ent2_only_probs.tolist()))
|
167 |
+
|
168 |
+
# Only use superlatives if thresholds work out.
|
169 |
+
m1, m2 = probs[(-probs).argsort()[:2]]
|
170 |
+
if m1 < self.baseline_threshold * m2:
|
171 |
+
self.counts["n_rec_sup"] += 1
|
172 |
+
for tokens in ent.superlatives:
|
173 |
+
self.counts["n_sup"] += 1
|
174 |
+
sup = None
|
175 |
+
for heuristic_index, heuristic in enumerate(self.heuristics.superlatives):
|
176 |
+
if any(tok.text in heuristic.keywords for tok in tokens):
|
177 |
+
texts.append('sup:'+' '.join([tok.text for tok in tokens if tok.text in heuristic.keywords]))
|
178 |
+
sup = heuristic.callback(env)
|
179 |
+
self.counts[f"n_sup_{heuristic.keywords[0]}"] += 1
|
180 |
+
break
|
181 |
+
if sup is not None:
|
182 |
+
# Could use `probs` or `head_probs` here?
|
183 |
+
precond = head_probs if self.superlative_head_only else probs
|
184 |
+
probs = L.meet(np.expand_dims(precond, axis=1)*np.expand_dims(precond, axis=0), sup).sum(axis=1)
|
185 |
+
probs = probs / probs.sum()
|
186 |
+
return_probs.append((probs.tolist(), None))
|
187 |
+
|
188 |
+
if root:
|
189 |
+
assert len(texts) == len(return_probs)
|
190 |
+
return probs, (texts, return_probs, tuple(str(chunk) for chunk in chunks.values()))
|
191 |
+
return probs
|
192 |
+
|
193 |
+
def get_head(self, doc) -> Token:
|
194 |
+
"""Return the token that is the head of the dependency parse. """
|
195 |
+
for token in doc:
|
196 |
+
if token.head.i == token.i:
|
197 |
+
return token
|
198 |
+
return None
|
199 |
+
|
200 |
+
def get_chunks(self, doc) -> Dict[int, Any]:
|
201 |
+
"""Return a dictionary mapping sentence indices to their noun chunk."""
|
202 |
+
chunks = {}
|
203 |
+
for chunk in doc.noun_chunks:
|
204 |
+
for idx in range(chunk.start, chunk.end):
|
205 |
+
chunks[idx] = chunk
|
206 |
+
return chunks
|
207 |
+
|
208 |
+
@overrides
|
209 |
+
def get_stats(self) -> Dict[str, Number]:
|
210 |
+
"""Summary statistics that have been tracked on this object."""
|
211 |
+
stats = dict(self.counts)
|
212 |
+
n_rel_caught = sum(v for k, v in stats.items() if k.startswith("n_rel_"))
|
213 |
+
n_sup_caught = sum(v for k, v in stats.items() if k.startswith("n_sup_"))
|
214 |
+
stats.update({
|
215 |
+
"p_rel_caught": n_rel_caught / (self.counts["n_rel"] + 1e-9),
|
216 |
+
"p_sup_caught": n_sup_caught / (self.counts["n_sup"] + 1e-9),
|
217 |
+
"p_rec_rel": self.counts["n_rec_rel"] / (self.counts["n_rec"] + 1e-9),
|
218 |
+
"p_rec_sup": self.counts["n_rec_sup"] / (self.counts["n_rec"] + 1e-9),
|
219 |
+
"p_0th_noun": self.counts["n_0th_noun"] / (self.counts["n_total"] + 1e-9),
|
220 |
+
"p_full_expr": self.counts["n_full_expr"] / (self.counts["n_total"] + 1e-9),
|
221 |
+
"avg_rec": self.counts["n_rec"] / self.counts["n_total"],
|
222 |
+
})
|
223 |
+
return stats
|
224 |
+
|
225 |
+
def _filter(self,
|
226 |
+
caption: str,
|
227 |
+
env: "Environment",
|
228 |
+
root: bool = False,
|
229 |
+
expand: float = None,
|
230 |
+
) -> np.ndarray:
|
231 |
+
"""Wrap a filter call in a consistent way for all recursions."""
|
232 |
+
kwargs = {
|
233 |
+
"softmax": not self.args.sigmoid,
|
234 |
+
"temperature": self.args.temperature,
|
235 |
+
}
|
236 |
+
if root:
|
237 |
+
return env.filter(caption, area_threshold=self.box_area_threshold, **kwargs)
|
238 |
+
else:
|
239 |
+
return env.filter(caption, **kwargs)
|
AlphaCLIP/eval/rec_zs_test/methods/random_method.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""A naive baseline method: just pass the full expression to CLIP."""
|
2 |
+
|
3 |
+
from overrides import overrides
|
4 |
+
from typing import Dict, Any
|
5 |
+
import random
|
6 |
+
from argparse import Namespace
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
|
10 |
+
from .ref_method import RefMethod
|
11 |
+
|
12 |
+
|
13 |
+
class Random(RefMethod):
|
14 |
+
"""CLIP-only baseline where each box is evaluated with the full expression."""
|
15 |
+
|
16 |
+
def __init__(self, args: Namespace):
|
17 |
+
self.box_area_threshold = args.box_area_threshold
|
18 |
+
|
19 |
+
@overrides
|
20 |
+
def execute(self, caption: str, env: "Environment") -> Dict[str, Any]:
|
21 |
+
probs = env.filter_area(self.box_area_threshold)*env.uniform()
|
22 |
+
random_ordering = list(range(len(env.boxes)))
|
23 |
+
random.shuffle(random_ordering)
|
24 |
+
random_ordering = np.array(random_ordering)
|
25 |
+
pred = np.argmax(probs*random_ordering)
|
26 |
+
return {
|
27 |
+
"probs": probs.tolist(),
|
28 |
+
"pred": int(pred),
|
29 |
+
"text": caption.lower()
|
30 |
+
}
|
AlphaCLIP/eval/rec_zs_test/methods/ref_method.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Base class for a method for doing referring expressions."""
|
2 |
+
|
3 |
+
from typing import Dict, Any
|
4 |
+
from abc import ABCMeta, abstractmethod
|
5 |
+
|
6 |
+
|
7 |
+
class RefMethod(metaclass=ABCMeta):
|
8 |
+
@abstractmethod
|
9 |
+
def execute(self, caption: str, env: "Environment") -> Dict[str, Any]:
|
10 |
+
return NotImplemented
|
11 |
+
|
12 |
+
def get_stats(self) -> Dict[str, Any]:
|
13 |
+
return {}
|
AlphaCLIP/eval/rec_zs_test/output/.gitkeep
ADDED
File without changes
|
AlphaCLIP/eval/rec_zs_test/requirements.txt
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
attrs==21.2.0
|
2 |
+
blis==0.7.4
|
3 |
+
catalogue==2.0.4
|
4 |
+
certifi==2021.5.30
|
5 |
+
chardet==4.0.0
|
6 |
+
click==7.1.2
|
7 |
+
cymem==2.0.5
|
8 |
+
en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.0.0/en_core_web_sm-3.0.0-py3-none-any.whl
|
9 |
+
filelock==3.0.12
|
10 |
+
ftfy==6.0.3
|
11 |
+
huggingface-hub==0.0.12
|
12 |
+
idna==2.10
|
13 |
+
iniconfig==1.1.1
|
14 |
+
itsdangerous==2.0.1
|
15 |
+
joblib==1.0.1
|
16 |
+
MarkupSafe==2.0.1
|
17 |
+
murmurhash==1.0.5
|
18 |
+
numpy==1.21.0
|
19 |
+
overrides==6.1.0
|
20 |
+
packaging==21.0
|
21 |
+
pathy==0.6.0
|
22 |
+
Pillow==8.2.0
|
23 |
+
pluggy==0.13.1
|
24 |
+
preshed==3.0.5
|
25 |
+
py==1.10.0
|
26 |
+
pydantic==1.7.4
|
27 |
+
pyparsing==2.4.7
|
28 |
+
pytest==6.2.4
|
29 |
+
PyYAML==5.4.1
|
30 |
+
regex==2021.7.6
|
31 |
+
requests==2.25.1
|
32 |
+
ruamel.yaml==0.17.10
|
33 |
+
ruamel.yaml.clib==0.2.6
|
34 |
+
sacremoses==0.0.45
|
35 |
+
scipy==1.7.0
|
36 |
+
six==1.16.0
|
37 |
+
smart-open==5.1.0
|
38 |
+
spacy==3.0.6
|
39 |
+
spacy-legacy==3.0.7
|
40 |
+
srsly==2.4.1
|
41 |
+
thinc==8.0.7
|
42 |
+
timm==0.4.12
|
43 |
+
tokenizers==0.10.3
|
44 |
+
toml==0.10.2
|
45 |
+
tqdm==4.61.2
|
46 |
+
transformers==4.9.0
|
47 |
+
typer==0.3.2
|
48 |
+
typing-extensions==3.10.0.0
|
49 |
+
typing-utils==0.1.0
|
50 |
+
urllib3==1.26.6
|
51 |
+
wasabi==0.8.2
|
52 |
+
wcwidth==0.2.5
|
53 |
+
Werkzeug==2.0.1
|
AlphaCLIP/eval/rec_zs_test/run.sh
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
CUDA_VISIBLE_DEVICES=0 python main.py --input_file reclip_data/refcoco_val.jsonl --image_root ./data/train2014 --method parse --gradcam_alpha 0.5 0.5 --box_representation_method full,blur --box_method_aggregator sum --clip_model ViT-B/16,ViT-L/14 --detector_file reclip_data/refcoco_dets_dict.json --cache_path ./cache
|
AlphaCLIP/eval/rec_zs_test/run_multi_gpus.sh
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
CUDA_VISIBLE_DEVICES=0 python main.py --input_file reclip_data/refcoco_val.jsonl --image_root ./data/train2014 --method parse --gradcam_alpha 0.5 0.5 --box_method_aggregator sum --clip_model ViT-B/16,ViT-L/14 --box_representation_method full,blur --detector_file reclip_data/refcoco_dets_dict.json --cache_path ./cache --part "8,0" &
|
2 |
+
|
3 |
+
CUDA_VISIBLE_DEVICES=1 python main.py --input_file reclip_data/refcoco_val.jsonl --image_root ./data/train2014 --method parse --gradcam_alpha 0.5 0.5 --box_method_aggregator sum --clip_model ViT-B/16,ViT-L/14 --box_representation_method full,blur --detector_file reclip_data/refcoco_dets_dict.json --cache_path ./cache --part "8,1" &
|
4 |
+
|
5 |
+
CUDA_VISIBLE_DEVICES=2 python main.py --input_file reclip_data/refcoco_val.jsonl --image_root ./data/train2014 --method parse --gradcam_alpha 0.5 0.5 --box_method_aggregator sum --clip_model ViT-B/16,ViT-L/14 --box_representation_method full,blur --detector_file reclip_data/refcoco_dets_dict.json --cache_path ./cache --part "8,2" &
|
6 |
+
|
7 |
+
CUDA_VISIBLE_DEVICES=3 python main.py --input_file reclip_data/refcoco_val.jsonl --image_root ./data/train2014 --method parse --gradcam_alpha 0.5 0.5 --box_method_aggregator sum --clip_model ViT-B/16,ViT-L/14 --box_representation_method full,blur --detector_file reclip_data/refcoco_dets_dict.json --cache_path ./cache --part "8,3" &
|
8 |
+
|
9 |
+
CUDA_VISIBLE_DEVICES=4 python main.py --input_file reclip_data/refcoco_val.jsonl --image_root ./data/train2014 --method parse --gradcam_alpha 0.5 0.5 --box_method_aggregator sum --clip_model ViT-B/16,ViT-L/14 --box_representation_method full,blur --detector_file reclip_data/refcoco_dets_dict.json --cache_path ./cache --part "8,4" &
|
10 |
+
|
11 |
+
CUDA_VISIBLE_DEVICES=5 python main.py --input_file reclip_data/refcoco_val.jsonl --image_root ./data/train2014 --method parse --gradcam_alpha 0.5 0.5 --box_method_aggregator sum --clip_model ViT-B/16,ViT-L/14 --box_representation_method full,blur --detector_file reclip_data/refcoco_dets_dict.json --cache_path ./cache --part "8,5" &
|
12 |
+
|
13 |
+
CUDA_VISIBLE_DEVICES=6 python main.py --input_file reclip_data/refcoco_val.jsonl --image_root ./data/train2014 --method parse --gradcam_alpha 0.5 0.5 --box_method_aggregator sum --clip_model ViT-B/16,ViT-L/14 --box_representation_method full,blur --detector_file reclip_data/refcoco_dets_dict.json --cache_path ./cache --part "8,6" &
|
14 |
+
|
15 |
+
CUDA_VISIBLE_DEVICES=7 python main.py --input_file reclip_data/refcoco_val.jsonl --image_root ./data/train2014 --method parse --gradcam_alpha 0.5 0.5 --box_method_aggregator sum --clip_model ViT-B/16,ViT-L/14 --box_representation_method full,blur --detector_file reclip_data/refcoco_dets_dict.json --cache_path ./cache --part "8,7"
|
AlphaCLIP/hubconf.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from alpha_clip.alpha_clip import tokenize as _tokenize, load as _load, available_models as _available_models
|
2 |
+
import re
|
3 |
+
import string
|
4 |
+
|
5 |
+
dependencies = ["torch", "torchvision", "ftfy", "regex", "tqdm"]
|
6 |
+
|
7 |
+
# For compatibility (cannot include special characters in function name)
|
8 |
+
model_functions = { model: re.sub(f'[{string.punctuation}]', '_', model) for model in _available_models()}
|
9 |
+
|
10 |
+
def _create_hub_entrypoint(model):
|
11 |
+
def entrypoint(**kwargs):
|
12 |
+
return _load(model, **kwargs)
|
13 |
+
|
14 |
+
entrypoint.__doc__ = f"""Loads the {model} CLIP model
|
15 |
+
|
16 |
+
Parameters
|
17 |
+
----------
|
18 |
+
device : Union[str, torch.device]
|
19 |
+
The device to put the loaded model
|
20 |
+
|
21 |
+
jit : bool
|
22 |
+
Whether to load the optimized JIT model or more hackable non-JIT model (default).
|
23 |
+
|
24 |
+
download_root: str
|
25 |
+
path to download the model files; by default, it uses "~/.cache/clip"
|
26 |
+
|
27 |
+
Returns
|
28 |
+
-------
|
29 |
+
model : torch.nn.Module
|
30 |
+
The {model} CLIP model
|
31 |
+
|
32 |
+
preprocess : Callable[[PIL.Image], torch.Tensor]
|
33 |
+
A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
|
34 |
+
"""
|
35 |
+
return entrypoint
|
36 |
+
|
37 |
+
def tokenize():
|
38 |
+
return _tokenize
|
39 |
+
|
40 |
+
_entrypoints = {model_functions[model]: _create_hub_entrypoint(model) for model in _available_models()}
|
41 |
+
|
42 |
+
globals().update(_entrypoints)
|
AlphaCLIP/requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
ftfy
|
2 |
+
regex
|
3 |
+
tqdm
|
4 |
+
torch
|
5 |
+
torchvision
|
AlphaCLIP/setup.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
import pkg_resources
|
4 |
+
from setuptools import setup, find_packages
|
5 |
+
|
6 |
+
setup(
|
7 |
+
name="alpha_clip",
|
8 |
+
py_modules=["alpha_clip"],
|
9 |
+
version="1.0",
|
10 |
+
description="",
|
11 |
+
author="OpenAI&ZeyiSun",
|
12 |
+
packages=find_packages(exclude=["tests*"]),
|
13 |
+
install_requires=[
|
14 |
+
str(r)
|
15 |
+
for r in pkg_resources.parse_requirements(
|
16 |
+
open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
|
17 |
+
)
|
18 |
+
],
|
19 |
+
include_package_data=True,
|
20 |
+
extras_require={'dev': ['pytest']},
|
21 |
+
)
|
README.md
CHANGED
@@ -4,7 +4,7 @@ emoji: 🏢
|
|
4 |
colorFrom: green
|
5 |
colorTo: red
|
6 |
sdk: gradio
|
7 |
-
sdk_version:
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: mit
|
|
|
4 |
colorFrom: green
|
5 |
colorTo: red
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 3.48.0
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: mit
|
app.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import sys
|
3 |
+
import torch
|
4 |
+
from omegaconf import OmegaConf
|
5 |
+
from PIL import Image
|
6 |
+
from diffusers import StableDiffusionInpaintPipeline
|
7 |
+
from model.clip_away import CLIPAway
|
8 |
+
import cv2
|
9 |
+
import numpy as np
|
10 |
+
import argparse
|
11 |
+
|
12 |
+
# Parse command line arguments
|
13 |
+
parser = argparse.ArgumentParser()
|
14 |
+
parser.add_argument("--config", type=str, default="config/inference_config.yaml", help="Path to the config file")
|
15 |
+
parser.add_argument("--share", action="store_true", help="Share the interface if provided")
|
16 |
+
args = parser.parse_args()
|
17 |
+
|
18 |
+
# Load configuration and models
|
19 |
+
config = OmegaConf.load(args.config)
|
20 |
+
sd_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
|
21 |
+
"runwayml/stable-diffusion-inpainting", safety_checker=None, torch_dtype=torch.float32
|
22 |
+
)
|
23 |
+
clipaway = CLIPAway(
|
24 |
+
sd_pipe=sd_pipeline,
|
25 |
+
image_encoder_path=config.image_encoder_path,
|
26 |
+
ip_ckpt=config.ip_adapter_ckpt_path,
|
27 |
+
alpha_clip_path=config.alpha_clip_ckpt_pth,
|
28 |
+
config=config,
|
29 |
+
alpha_clip_id=config.alpha_clip_id,
|
30 |
+
device=config.device,
|
31 |
+
num_tokens=4
|
32 |
+
)
|
33 |
+
|
34 |
+
def dilate_mask(mask, kernel_size=5, iterations=5):
|
35 |
+
mask = mask.convert("L")
|
36 |
+
kernel = np.ones((kernel_size, kernel_size), np.uint8)
|
37 |
+
mask = cv2.dilate(np.array(mask), kernel, iterations=iterations)
|
38 |
+
return Image.fromarray(mask)
|
39 |
+
|
40 |
+
def combine_masks(uploaded_mask, sketched_mask):
|
41 |
+
if uploaded_mask is not None:
|
42 |
+
return uploaded_mask
|
43 |
+
elif sketched_mask is not None:
|
44 |
+
return sketched_mask
|
45 |
+
else:
|
46 |
+
raise ValueError("Please provide a mask")
|
47 |
+
|
48 |
+
def remove_obj(image, uploaded_mask, seed):
|
49 |
+
image_pil, sketched_mask = image["image"], image["mask"]
|
50 |
+
mask = dilate_mask(combine_masks(uploaded_mask, sketched_mask))
|
51 |
+
seed = int(seed)
|
52 |
+
latents = torch.randn((1, 4, 64, 64), generator=torch.Generator().manual_seed(seed)).to("cuda")
|
53 |
+
final_image = clipaway.generate(
|
54 |
+
prompt=[""], scale=1, seed=seed,
|
55 |
+
pil_image=[image_pil], alpha=[mask], strength=1, latents=latents
|
56 |
+
)[0]
|
57 |
+
return final_image
|
58 |
+
|
59 |
+
# Define example data
|
60 |
+
examples = [
|
61 |
+
["assets/gradio_examples/images/1.jpg", "assets/gradio_examples/masks/1.png", 42],
|
62 |
+
["assets/gradio_examples/images/2.jpg", "assets/gradio_examples/masks/2.png", 42],
|
63 |
+
["assets/gradio_examples/images/3.jpg", "assets/gradio_examples/masks/3.png", 464],
|
64 |
+
["assets/gradio_examples/images/4.jpg", "assets/gradio_examples/masks/4.png", 2024],
|
65 |
+
]
|
66 |
+
|
67 |
+
# Define the Gradio interface
|
68 |
+
with gr.Blocks() as demo:
|
69 |
+
gr.Markdown("<h1 style='text-align:center'>CLIPAway: Harmonizing Focused Embeddings for Removing Objects via Diffusion Models</h1>")
|
70 |
+
gr.Markdown("""
|
71 |
+
<div style='display:flex; justify-content:center; align-items:center;'>
|
72 |
+
<a href='https://arxiv.org/abs/2406.09368' style="margin:10px;">Paper</a> |
|
73 |
+
<a href='https://yigitekin.github.io/CLIPAway/' style="margin:10px;">Project Website</a> |
|
74 |
+
<a href='https://github.com/YigitEkin/CLIPAway' style="margin:10px;">GitHub</a>
|
75 |
+
</div>
|
76 |
+
""")
|
77 |
+
gr.Markdown("""
|
78 |
+
This application allows you to remove objects from images using the CLIPAway method with diffusion models.
|
79 |
+
To use this tool:
|
80 |
+
1. Upload an image.
|
81 |
+
2. Either Sketch a mask over the object you want to remove or upload a pre-defined mask if you have one.
|
82 |
+
4. Set the seed for reproducibility (default is 42).
|
83 |
+
5. Click 'Remove Object' to process the image.
|
84 |
+
6. The result will be displayed on the right side.
|
85 |
+
Note: The mask should be a binary image where the object to be removed is white and the background is black.
|
86 |
+
""")
|
87 |
+
|
88 |
+
with gr.Row():
|
89 |
+
with gr.Column():
|
90 |
+
image_input = gr.Image(label="Upload Image and Sketch Mask", type="pil", tool="sketch")
|
91 |
+
uploaded_mask = gr.Image(label="Upload Mask (Optional)", type="pil", optional=True)
|
92 |
+
seed_input = gr.Number(value=42, label="Seed")
|
93 |
+
process_button = gr.Button("Remove Object")
|
94 |
+
with gr.Column():
|
95 |
+
result_image = gr.Image(label="Result")
|
96 |
+
|
97 |
+
process_button.click(
|
98 |
+
fn=remove_obj,
|
99 |
+
inputs=[image_input, uploaded_mask, seed_input],
|
100 |
+
outputs=result_image
|
101 |
+
)
|
102 |
+
|
103 |
+
gr.Examples(
|
104 |
+
examples=examples,
|
105 |
+
inputs=[image_input, uploaded_mask, seed_input],
|
106 |
+
outputs=result_image
|
107 |
+
)
|
108 |
+
|
109 |
+
# Launch the interface with caching
|
110 |
+
if args.share:
|
111 |
+
demo.launch(share=True)
|
112 |
+
else:
|
113 |
+
demo.launch()
|
clip_l14_grit+mim_fultune_6xe.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a5f3f2e24459e9764d9f4b4c053fb354dc9d508bd8f647b952402d6860bc9c3d
|
3 |
+
size 1216760175
|
config/inference_config.yaml
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
device: "cuda"
|
2 |
+
root_path: assets/gradio_examples
|
3 |
+
image_encoder_path: image_encoder
|
4 |
+
alpha_clip_ckpt_pth: clip_l14_grit+mim_fultune_6xe.pth
|
5 |
+
alpha_clip_id: ViT-L/14
|
6 |
+
ip_adapter_ckpt_path: ip-adapter_sd15.bin
|
7 |
+
sd_model_key: "runwayml/stable-diffusion-inpainting"
|
8 |
+
number_of_hidden_layers: 6
|
9 |
+
alpha_clip_embed_dim: 768
|
10 |
+
ip_adapter_embed_dim: 1024
|
11 |
+
mlp_projection_layer_ckpt_path: model.safetensors
|
12 |
+
save_path_prefix: test/results
|
13 |
+
seed: 42
|
14 |
+
scale: 1
|
15 |
+
strength: 1
|
16 |
+
display_focused_embeds: True
|
image_encoder/config.json
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "./image_encoder",
|
3 |
+
"architectures": [
|
4 |
+
"CLIPVisionModelWithProjection"
|
5 |
+
],
|
6 |
+
"attention_dropout": 0.0,
|
7 |
+
"dropout": 0.0,
|
8 |
+
"hidden_act": "gelu",
|
9 |
+
"hidden_size": 1280,
|
10 |
+
"image_size": 224,
|
11 |
+
"initializer_factor": 1.0,
|
12 |
+
"initializer_range": 0.02,
|
13 |
+
"intermediate_size": 5120,
|
14 |
+
"layer_norm_eps": 1e-05,
|
15 |
+
"model_type": "clip_vision_model",
|
16 |
+
"num_attention_heads": 16,
|
17 |
+
"num_channels": 3,
|
18 |
+
"num_hidden_layers": 32,
|
19 |
+
"patch_size": 14,
|
20 |
+
"projection_dim": 1024,
|
21 |
+
"torch_dtype": "float16",
|
22 |
+
"transformers_version": "4.28.0.dev0"
|
23 |
+
}
|
image_encoder/pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3d3ec1e66737f77a4f3bc2df3c52eacefc69ce7825e2784183b1d4e9877d9193
|
3 |
+
size 2528481905
|
ip-adapter_sd15.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:68e1df30d760f280e578c302f1e73b37ea08654eff16a31153588047affe0058
|
3 |
+
size 44642825
|
model.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ade94c0505170a7698afe8ad4b4fb2307d06f67917b877cf1fd694a43cd6e335
|
3 |
+
size 22877152
|
model/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .clip_away import CLIPAway
|
2 |
+
|
3 |
+
__all__ = [
|
4 |
+
"CLIPAway"
|
5 |
+
]
|
model/attention_processor.py
ADDED
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
taken from https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/attention_processor.py
|
3 |
+
"""
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from einops import rearrange, repeat
|
8 |
+
|
9 |
+
|
10 |
+
class AttnProcessor(nn.Module):
|
11 |
+
r"""
|
12 |
+
Default processor for performing attention-related computations.
|
13 |
+
"""
|
14 |
+
|
15 |
+
def __init__(
|
16 |
+
self,
|
17 |
+
hidden_size=None,
|
18 |
+
cross_attention_dim=None,
|
19 |
+
):
|
20 |
+
super().__init__()
|
21 |
+
|
22 |
+
def __call__(
|
23 |
+
self,
|
24 |
+
attn,
|
25 |
+
hidden_states,
|
26 |
+
encoder_hidden_states=None,
|
27 |
+
attention_mask=None,
|
28 |
+
temb=None,
|
29 |
+
):
|
30 |
+
residual = hidden_states
|
31 |
+
|
32 |
+
if attn.spatial_norm is not None:
|
33 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
34 |
+
|
35 |
+
input_ndim = hidden_states.ndim
|
36 |
+
|
37 |
+
if input_ndim == 4:
|
38 |
+
batch_size, channel, height, width = hidden_states.shape
|
39 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
40 |
+
|
41 |
+
batch_size, sequence_length, _ = (
|
42 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
43 |
+
)
|
44 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
45 |
+
|
46 |
+
if attn.group_norm is not None:
|
47 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
48 |
+
|
49 |
+
query = attn.to_q(hidden_states)
|
50 |
+
|
51 |
+
if encoder_hidden_states is None:
|
52 |
+
encoder_hidden_states = hidden_states
|
53 |
+
elif attn.norm_cross:
|
54 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
55 |
+
|
56 |
+
key = attn.to_k(encoder_hidden_states)
|
57 |
+
value = attn.to_v(encoder_hidden_states)
|
58 |
+
|
59 |
+
query = attn.head_to_batch_dim(query)
|
60 |
+
key = attn.head_to_batch_dim(key)
|
61 |
+
value = attn.head_to_batch_dim(value)
|
62 |
+
|
63 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
64 |
+
hidden_states = torch.bmm(attention_probs, value)
|
65 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
66 |
+
|
67 |
+
# linear proj
|
68 |
+
hidden_states = attn.to_out[0](hidden_states)
|
69 |
+
# dropout
|
70 |
+
hidden_states = attn.to_out[1](hidden_states)
|
71 |
+
|
72 |
+
if input_ndim == 4:
|
73 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
74 |
+
|
75 |
+
if attn.residual_connection:
|
76 |
+
hidden_states = hidden_states + residual
|
77 |
+
|
78 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
79 |
+
|
80 |
+
return hidden_states
|
81 |
+
|
82 |
+
|
83 |
+
class IPAttnProcessor(nn.Module):
|
84 |
+
r"""
|
85 |
+
Attention processor for IP-Adapater.
|
86 |
+
Args:
|
87 |
+
hidden_size (`int`):
|
88 |
+
The hidden size of the attention layer.
|
89 |
+
cross_attention_dim (`int`):
|
90 |
+
The number of channels in the `encoder_hidden_states`.
|
91 |
+
scale (`float`, defaults to 1.0):
|
92 |
+
the weight scale of image prompt.
|
93 |
+
num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
|
94 |
+
The context length of the image features.
|
95 |
+
"""
|
96 |
+
|
97 |
+
def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
|
98 |
+
super().__init__()
|
99 |
+
|
100 |
+
self.hidden_size = hidden_size
|
101 |
+
self.cross_attention_dim = cross_attention_dim
|
102 |
+
self.scale = scale
|
103 |
+
self.num_tokens = num_tokens
|
104 |
+
|
105 |
+
self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
106 |
+
self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
107 |
+
|
108 |
+
def __call__(
|
109 |
+
self,
|
110 |
+
attn,
|
111 |
+
hidden_states,
|
112 |
+
encoder_hidden_states=None,
|
113 |
+
attention_mask=None,
|
114 |
+
temb=None,
|
115 |
+
):
|
116 |
+
residual = hidden_states
|
117 |
+
|
118 |
+
if attn.spatial_norm is not None:
|
119 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
120 |
+
|
121 |
+
input_ndim = hidden_states.ndim
|
122 |
+
|
123 |
+
if input_ndim == 4:
|
124 |
+
batch_size, channel, height, width = hidden_states.shape
|
125 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
126 |
+
|
127 |
+
batch_size, sequence_length, _ = (
|
128 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
129 |
+
)
|
130 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
131 |
+
|
132 |
+
if attn.group_norm is not None:
|
133 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
134 |
+
|
135 |
+
query = attn.to_q(hidden_states)
|
136 |
+
|
137 |
+
if encoder_hidden_states is None:
|
138 |
+
encoder_hidden_states = hidden_states
|
139 |
+
else:
|
140 |
+
# get encoder_hidden_states, ip_hidden_states
|
141 |
+
end_pos = encoder_hidden_states.shape[1] - self.num_tokens
|
142 |
+
encoder_hidden_states, ip_hidden_states = (
|
143 |
+
encoder_hidden_states[:, :end_pos, :],
|
144 |
+
encoder_hidden_states[:, end_pos:, :],
|
145 |
+
)
|
146 |
+
if attn.norm_cross:
|
147 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
148 |
+
|
149 |
+
key = attn.to_k(encoder_hidden_states)
|
150 |
+
value = attn.to_v(encoder_hidden_states)
|
151 |
+
|
152 |
+
query = attn.head_to_batch_dim(query)
|
153 |
+
key = attn.head_to_batch_dim(key)
|
154 |
+
value = attn.head_to_batch_dim(value)
|
155 |
+
|
156 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
157 |
+
#!MASK HERE
|
158 |
+
hidden_states = torch.bmm(attention_probs, value)
|
159 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
160 |
+
|
161 |
+
# for ip-adapter
|
162 |
+
ip_key = self.to_k_ip(ip_hidden_states)
|
163 |
+
ip_value = self.to_v_ip(ip_hidden_states)
|
164 |
+
|
165 |
+
ip_key = attn.head_to_batch_dim(ip_key)
|
166 |
+
ip_value = attn.head_to_batch_dim(ip_value)
|
167 |
+
|
168 |
+
ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
|
169 |
+
#!MASK HERE
|
170 |
+
self.attn_map = ip_attention_probs
|
171 |
+
ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
|
172 |
+
ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
|
173 |
+
|
174 |
+
hidden_states = hidden_states + self.scale * ip_hidden_states
|
175 |
+
|
176 |
+
# linear proj
|
177 |
+
hidden_states = attn.to_out[0](hidden_states)
|
178 |
+
# dropout
|
179 |
+
hidden_states = attn.to_out[1](hidden_states)
|
180 |
+
|
181 |
+
if input_ndim == 4:
|
182 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
183 |
+
|
184 |
+
if attn.residual_connection:
|
185 |
+
hidden_states = hidden_states + residual
|
186 |
+
|
187 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
188 |
+
|
189 |
+
return hidden_states
|
model/clip_away.py
ADDED
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
modified from from https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/ip_adapter.py
|
3 |
+
"""
|
4 |
+
import os
|
5 |
+
from typing import List
|
6 |
+
import torch
|
7 |
+
from PIL import Image
|
8 |
+
from torchvision import transforms
|
9 |
+
from transformers import CLIPVisionModelWithProjection
|
10 |
+
import alpha_clip
|
11 |
+
from .utils import get_generator
|
12 |
+
from .attention_processor import AttnProcessor, IPAttnProcessor
|
13 |
+
from safetensors import safe_open
|
14 |
+
from safetensors.torch import load_model
|
15 |
+
import numpy as np
|
16 |
+
|
17 |
+
import torch.nn as nn
|
18 |
+
|
19 |
+
|
20 |
+
class ImageProjModel(torch.nn.Module):
|
21 |
+
"""Projection Model of IP-Adapter"""
|
22 |
+
|
23 |
+
def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
|
24 |
+
super().__init__()
|
25 |
+
|
26 |
+
self.generator = None
|
27 |
+
self.cross_attention_dim = cross_attention_dim
|
28 |
+
self.clip_extra_context_tokens = clip_extra_context_tokens
|
29 |
+
self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
|
30 |
+
self.norm = torch.nn.LayerNorm(cross_attention_dim)
|
31 |
+
|
32 |
+
def forward(self, image_embeds):
|
33 |
+
embeds = image_embeds
|
34 |
+
clip_extra_context_tokens = self.proj(embeds).reshape(
|
35 |
+
-1, self.clip_extra_context_tokens, self.cross_attention_dim
|
36 |
+
)
|
37 |
+
clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
|
38 |
+
return clip_extra_context_tokens
|
39 |
+
|
40 |
+
class CLIPAway:
|
41 |
+
def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, alpha_clip_path, config, alpha_clip_id="ViT-L/14", device="cuda", num_tokens=4):
|
42 |
+
super().__init__()
|
43 |
+
self.device = device
|
44 |
+
self.ipadapter_image_encoder_path = image_encoder_path
|
45 |
+
self.ipadapter_ckpt = ip_ckpt
|
46 |
+
self.num_tokens = num_tokens
|
47 |
+
|
48 |
+
self.pipe = sd_pipe.to(self.device)
|
49 |
+
self.set_ip_adapter()
|
50 |
+
alpha_clip_model, alpha_clip_preprocess = alpha_clip.load(alpha_clip_id, alpha_vision_ckpt_pth=alpha_clip_path, device=device)
|
51 |
+
|
52 |
+
# load image encoder
|
53 |
+
self.image_encoder = alpha_clip_model.visual.to(self.device, dtype=torch.float32)
|
54 |
+
|
55 |
+
self.clip_proj = CLIPVisionModelWithProjection.from_pretrained(self.ipadapter_image_encoder_path).to(
|
56 |
+
self.device, dtype=torch.float32
|
57 |
+
)
|
58 |
+
self.alpha_clip_image_processor = alpha_clip_preprocess
|
59 |
+
|
60 |
+
# preprocess mask transformation for alpha clip
|
61 |
+
if "@336" in alpha_clip_id:
|
62 |
+
self.mask_transform = transforms.Compose([
|
63 |
+
transforms.ToTensor(),
|
64 |
+
transforms.Resize((336, 336)), # change to (336,336) when using ViT-L/14@336px
|
65 |
+
transforms.Normalize(0.5, 0.26)
|
66 |
+
])
|
67 |
+
else:
|
68 |
+
self.mask_transform = transforms.Compose([
|
69 |
+
transforms.ToTensor(),
|
70 |
+
transforms.Resize((224, 224)), # change to (336,336) when using ViT-L/14@336px
|
71 |
+
transforms.Normalize(0.5, 0.26)
|
72 |
+
])
|
73 |
+
# image proj model
|
74 |
+
self.image_proj_model = self.init_proj()
|
75 |
+
|
76 |
+
self.load_ip_adapter()
|
77 |
+
self.mlp_projection_layer = self.generate_projection_layer(config)
|
78 |
+
|
79 |
+
print(config.mlp_projection_layer_ckpt_path, type(config.mlp_projection_layer_ckpt_path) )
|
80 |
+
if config.mlp_projection_layer_ckpt_path is not None:
|
81 |
+
self.load_projection_layer(config.mlp_projection_layer_ckpt_path)
|
82 |
+
|
83 |
+
def load_projection_layer(self, path):
|
84 |
+
load_model(self.mlp_projection_layer, path)
|
85 |
+
print("Projection layer loaded from", path)
|
86 |
+
|
87 |
+
def generate_projection_layer(self, config):
|
88 |
+
projection_layer = nn.ModuleList()
|
89 |
+
|
90 |
+
for i in range(config.number_of_hidden_layers):
|
91 |
+
if i < config.number_of_hidden_layers // 2:
|
92 |
+
projection_layer.append(nn.Linear(config.alpha_clip_embed_dim, config.alpha_clip_embed_dim))
|
93 |
+
projection_layer.append(nn.LayerNorm(config.alpha_clip_embed_dim))
|
94 |
+
elif i == config.number_of_hidden_layers // 2:
|
95 |
+
projection_layer.append(nn.Linear(config.alpha_clip_embed_dim, config.ip_adapter_embed_dim))
|
96 |
+
projection_layer.append(nn.LayerNorm(config.ip_adapter_embed_dim))
|
97 |
+
else:
|
98 |
+
projection_layer.append(nn.Linear(config.ip_adapter_embed_dim, config.ip_adapter_embed_dim))
|
99 |
+
projection_layer.append(nn.LayerNorm(config.ip_adapter_embed_dim))
|
100 |
+
projection_layer.append(nn.GELU())
|
101 |
+
|
102 |
+
projection_layer.append(nn.Linear(config.ip_adapter_embed_dim, config.ip_adapter_embed_dim))
|
103 |
+
|
104 |
+
return nn.Sequential(*projection_layer).to(self.device).to(torch.float32)
|
105 |
+
|
106 |
+
def init_proj(self):
|
107 |
+
image_proj_model = ImageProjModel(
|
108 |
+
cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
|
109 |
+
clip_embeddings_dim=self.clip_proj.config.projection_dim,
|
110 |
+
clip_extra_context_tokens=self.num_tokens,
|
111 |
+
).to(self.device, dtype=torch.float32)
|
112 |
+
return image_proj_model
|
113 |
+
|
114 |
+
def set_ip_adapter(self):
|
115 |
+
unet = self.pipe.unet
|
116 |
+
attn_procs = {}
|
117 |
+
for name in unet.attn_processors.keys():
|
118 |
+
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
|
119 |
+
if name.startswith("mid_block"):
|
120 |
+
hidden_size = unet.config.block_out_channels[-1]
|
121 |
+
elif name.startswith("up_blocks"):
|
122 |
+
block_id = int(name[len("up_blocks.")])
|
123 |
+
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
|
124 |
+
elif name.startswith("down_blocks"):
|
125 |
+
block_id = int(name[len("down_blocks.")])
|
126 |
+
hidden_size = unet.config.block_out_channels[block_id]
|
127 |
+
if cross_attention_dim is None:
|
128 |
+
attn_procs[name] = AttnProcessor().to(self.device)
|
129 |
+
else:
|
130 |
+
attn_procs[name] = IPAttnProcessor(
|
131 |
+
hidden_size=hidden_size,
|
132 |
+
cross_attention_dim=cross_attention_dim,
|
133 |
+
scale=1.0,
|
134 |
+
num_tokens=self.num_tokens,
|
135 |
+
).to(self.device, dtype=torch.float32)
|
136 |
+
unet.set_attn_processor(attn_procs)
|
137 |
+
|
138 |
+
def get_alpha_clip_embeds(self, pil_image, alpha):
|
139 |
+
clip_image = [self.alpha_clip_image_processor(image) for image in pil_image]
|
140 |
+
clip_image = torch.stack(clip_image).to(self.device, dtype=torch.float32)
|
141 |
+
masks = [self.mask_transform(mask) for mask in alpha]
|
142 |
+
masks = torch.stack(masks).to(self.device, dtype=torch.float32)
|
143 |
+
|
144 |
+
return self.image_encoder(clip_image, masks)
|
145 |
+
|
146 |
+
def load_ip_adapter(self):
|
147 |
+
if os.path.splitext(self.ipadapter_ckpt)[-1] == ".safetensors":
|
148 |
+
state_dict = {"image_proj": {}, "ip_adapter": {}}
|
149 |
+
with safe_open(self.ipadapter_ckpt, framework="pt", device="cpu") as f:
|
150 |
+
for key in f.keys():
|
151 |
+
if key.startswith("image_proj."):
|
152 |
+
state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
|
153 |
+
elif key.startswith("ip_adapter."):
|
154 |
+
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
|
155 |
+
else:
|
156 |
+
state_dict = torch.load(self.ipadapter_ckpt, map_location="cpu")
|
157 |
+
self.image_proj_model.load_state_dict(state_dict["image_proj"])
|
158 |
+
ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
|
159 |
+
ip_layers.load_state_dict(state_dict["ip_adapter"])
|
160 |
+
|
161 |
+
def get_complement_of_mask(self, mask):
|
162 |
+
return Image.fromarray((255 - np.array(mask[0])).astype(np.uint8))
|
163 |
+
|
164 |
+
def clipaway_projection_block(self, bg_embeds, fg_embeds):
|
165 |
+
projected_vector_magnitude = bg_embeds[0].dot(fg_embeds[0]) / fg_embeds[0].norm()
|
166 |
+
projected_vector = projected_vector_magnitude * fg_embeds / fg_embeds.norm()
|
167 |
+
return bg_embeds - projected_vector
|
168 |
+
|
169 |
+
def get_focused_embeddings(self, pil_image, alpha, use_projection_block=False):
|
170 |
+
# get focused alpha clip embeds
|
171 |
+
clip_image_embeds_fg = self.get_alpha_clip_embeds(pil_image, alpha)
|
172 |
+
clip_image_embeds_bg = self.get_alpha_clip_embeds(pil_image, [self.get_complement_of_mask(alpha)])
|
173 |
+
|
174 |
+
# mlp projection
|
175 |
+
projected_alpha_clip_embeds_fg = self.mlp_projection_layer(clip_image_embeds_fg)
|
176 |
+
projected_alpha_clip_embeds_bg = self.mlp_projection_layer(clip_image_embeds_bg)
|
177 |
+
|
178 |
+
# ip adapter logic
|
179 |
+
image_prompt_embeds_fg = self.image_proj_model(projected_alpha_clip_embeds_fg)
|
180 |
+
image_prompt_embeds_bg = self.image_proj_model(projected_alpha_clip_embeds_bg)
|
181 |
+
uncond_image_prompt_embeds = self.image_proj_model(self.mlp_projection_layer(torch.zeros_like(clip_image_embeds_fg)))
|
182 |
+
|
183 |
+
if use_projection_block:
|
184 |
+
# clipaway projection block
|
185 |
+
projected_alpha_clip_embeds = self.clipaway_projection_block(projected_alpha_clip_embeds_bg, projected_alpha_clip_embeds_fg)
|
186 |
+
image_prompt_embeds = self.image_proj_model(projected_alpha_clip_embeds)
|
187 |
+
return image_prompt_embeds, image_prompt_embeds_fg, image_prompt_embeds_bg, uncond_image_prompt_embeds
|
188 |
+
|
189 |
+
return image_prompt_embeds_fg, image_prompt_embeds_bg, uncond_image_prompt_embeds
|
190 |
+
|
191 |
+
|
192 |
+
def get_ipadapter_embeds(self, pil_image=None, alpha=None):
|
193 |
+
# get focused alpha clip embeds
|
194 |
+
clip_image_embeds_fg = self.get_alpha_clip_embeds(pil_image, alpha)
|
195 |
+
clip_image_embeds_bg = self.get_alpha_clip_embeds(pil_image, [self.get_complement_of_mask(alpha)])
|
196 |
+
|
197 |
+
# mlp projection
|
198 |
+
projected_alpha_clip_embeds_fg = self.mlp_projection_layer(clip_image_embeds_fg)
|
199 |
+
projected_alpha_clip_embeds_bg = self.mlp_projection_layer(clip_image_embeds_bg)
|
200 |
+
|
201 |
+
# clipaway projection block
|
202 |
+
projected_alpha_clip_embeds = self.clipaway_projection_block(projected_alpha_clip_embeds_bg, projected_alpha_clip_embeds_fg)
|
203 |
+
|
204 |
+
# ip adapter logic
|
205 |
+
image_prompt_embeds = self.image_proj_model(projected_alpha_clip_embeds)
|
206 |
+
uncond_image_prompt_embeds = self.image_proj_model(self.mlp_projection_layer(torch.zeros_like(clip_image_embeds_fg)))
|
207 |
+
|
208 |
+
return image_prompt_embeds, uncond_image_prompt_embeds
|
209 |
+
|
210 |
+
|
211 |
+
def set_scale(self, scale):
|
212 |
+
for attn_processor in self.pipe.unet.attn_processors.values():
|
213 |
+
if isinstance(attn_processor, IPAttnProcessor):
|
214 |
+
attn_processor.scale = scale
|
215 |
+
|
216 |
+
@torch.inference_mode()
|
217 |
+
def generate(
|
218 |
+
self,
|
219 |
+
pil_image=None,
|
220 |
+
alpha=None,
|
221 |
+
prompt=None,
|
222 |
+
negative_prompt=None,
|
223 |
+
image_prompt_embeds=None,
|
224 |
+
uncond_image_prompt_embeds=None,
|
225 |
+
scale=1.0,
|
226 |
+
num_samples=1,
|
227 |
+
seed=None,
|
228 |
+
guidance_scale=7.5,
|
229 |
+
num_inference_steps=50,
|
230 |
+
**kwargs,
|
231 |
+
):
|
232 |
+
self.set_scale(scale)
|
233 |
+
num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
|
234 |
+
|
235 |
+
if prompt is None:
|
236 |
+
prompt = "best quality, high quality"
|
237 |
+
if negative_prompt is None:
|
238 |
+
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
|
239 |
+
|
240 |
+
if not isinstance(prompt, List):
|
241 |
+
prompt = [prompt] * num_prompts
|
242 |
+
if not isinstance(negative_prompt, List):
|
243 |
+
negative_prompt = [negative_prompt] * num_prompts
|
244 |
+
|
245 |
+
if image_prompt_embeds is None or uncond_image_prompt_embeds is None:
|
246 |
+
image_prompt_embeds, uncond_image_prompt_embeds= self.get_ipadapter_embeds(pil_image=pil_image, alpha=alpha)
|
247 |
+
else:
|
248 |
+
image_prompt_embeds = image_prompt_embeds.to(self.device)
|
249 |
+
uncond_image_prompt_embeds = uncond_image_prompt_embeds.to(self.device)
|
250 |
+
|
251 |
+
bs_embed, seq_len, _ = image_prompt_embeds.shape
|
252 |
+
image_prompt_embeds = image_prompt_embeds.view(bs_embed, seq_len, -1)
|
253 |
+
uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed, seq_len, -1)
|
254 |
+
|
255 |
+
with torch.inference_mode():
|
256 |
+
prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
|
257 |
+
prompt,
|
258 |
+
device=self.device,
|
259 |
+
num_images_per_prompt=num_samples,
|
260 |
+
do_classifier_free_guidance=True,
|
261 |
+
negative_prompt=negative_prompt,
|
262 |
+
)
|
263 |
+
prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
|
264 |
+
negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
|
265 |
+
|
266 |
+
generator = get_generator(seed, self.device)
|
267 |
+
|
268 |
+
images = self.pipe(
|
269 |
+
prompt_embeds=prompt_embeds,
|
270 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
271 |
+
guidance_scale=guidance_scale,
|
272 |
+
num_inference_steps=num_inference_steps,
|
273 |
+
generator=generator,
|
274 |
+
image=pil_image,
|
275 |
+
mask_image=alpha,
|
276 |
+
**kwargs,
|
277 |
+
).images
|
278 |
+
|
279 |
+
return images
|
280 |
+
|
model/resampler.py
ADDED
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
taken from https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/resampler.py
|
3 |
+
"""
|
4 |
+
import math
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
from einops import rearrange
|
9 |
+
from einops.layers.torch import Rearrange
|
10 |
+
|
11 |
+
|
12 |
+
# FFN
|
13 |
+
def FeedForward(dim, mult=4):
|
14 |
+
inner_dim = int(dim * mult)
|
15 |
+
return nn.Sequential(
|
16 |
+
nn.LayerNorm(dim),
|
17 |
+
nn.Linear(dim, inner_dim, bias=False),
|
18 |
+
nn.GELU(),
|
19 |
+
nn.Linear(inner_dim, dim, bias=False),
|
20 |
+
)
|
21 |
+
|
22 |
+
|
23 |
+
def reshape_tensor(x, heads):
|
24 |
+
bs, length, width = x.shape
|
25 |
+
# (bs, length, width) --> (bs, length, n_heads, dim_per_head)
|
26 |
+
x = x.view(bs, length, heads, -1)
|
27 |
+
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
|
28 |
+
x = x.transpose(1, 2)
|
29 |
+
# (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
|
30 |
+
x = x.reshape(bs, heads, length, -1)
|
31 |
+
return x
|
32 |
+
|
33 |
+
|
34 |
+
class PerceiverAttention(nn.Module):
|
35 |
+
def __init__(self, *, dim, dim_head=64, heads=8):
|
36 |
+
super().__init__()
|
37 |
+
self.scale = dim_head**-0.5
|
38 |
+
self.dim_head = dim_head
|
39 |
+
self.heads = heads
|
40 |
+
inner_dim = dim_head * heads
|
41 |
+
|
42 |
+
self.norm1 = nn.LayerNorm(dim)
|
43 |
+
self.norm2 = nn.LayerNorm(dim)
|
44 |
+
|
45 |
+
self.to_q = nn.Linear(dim, inner_dim, bias=False)
|
46 |
+
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
|
47 |
+
self.to_out = nn.Linear(inner_dim, dim, bias=False)
|
48 |
+
|
49 |
+
def forward(self, x, latents):
|
50 |
+
"""
|
51 |
+
Args:
|
52 |
+
x (torch.Tensor): image features
|
53 |
+
shape (b, n1, D)
|
54 |
+
latent (torch.Tensor): latent features
|
55 |
+
shape (b, n2, D)
|
56 |
+
"""
|
57 |
+
x = self.norm1(x)
|
58 |
+
latents = self.norm2(latents)
|
59 |
+
|
60 |
+
b, l, _ = latents.shape
|
61 |
+
|
62 |
+
q = self.to_q(latents)
|
63 |
+
kv_input = torch.cat((x, latents), dim=-2)
|
64 |
+
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
|
65 |
+
|
66 |
+
q = reshape_tensor(q, self.heads)
|
67 |
+
k = reshape_tensor(k, self.heads)
|
68 |
+
v = reshape_tensor(v, self.heads)
|
69 |
+
|
70 |
+
# attention
|
71 |
+
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
|
72 |
+
weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
|
73 |
+
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
|
74 |
+
out = weight @ v
|
75 |
+
|
76 |
+
out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
|
77 |
+
|
78 |
+
return self.to_out(out)
|
79 |
+
|
80 |
+
|
81 |
+
class Resampler(nn.Module):
|
82 |
+
def __init__(
|
83 |
+
self,
|
84 |
+
dim=1024,
|
85 |
+
depth=8,
|
86 |
+
dim_head=64,
|
87 |
+
heads=16,
|
88 |
+
num_queries=8,
|
89 |
+
embedding_dim=768,
|
90 |
+
output_dim=1024,
|
91 |
+
ff_mult=4,
|
92 |
+
max_seq_len: int = 257, # CLIP tokens + CLS token
|
93 |
+
apply_pos_emb: bool = False,
|
94 |
+
num_latents_mean_pooled: int = 0, # number of latents derived from mean pooled representation of the sequence
|
95 |
+
):
|
96 |
+
super().__init__()
|
97 |
+
self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None
|
98 |
+
|
99 |
+
self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
|
100 |
+
|
101 |
+
self.proj_in = nn.Linear(embedding_dim, dim)
|
102 |
+
|
103 |
+
self.proj_out = nn.Linear(dim, output_dim)
|
104 |
+
self.norm_out = nn.LayerNorm(output_dim)
|
105 |
+
|
106 |
+
self.to_latents_from_mean_pooled_seq = (
|
107 |
+
nn.Sequential(
|
108 |
+
nn.LayerNorm(dim),
|
109 |
+
nn.Linear(dim, dim * num_latents_mean_pooled),
|
110 |
+
Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
|
111 |
+
)
|
112 |
+
if num_latents_mean_pooled > 0
|
113 |
+
else None
|
114 |
+
)
|
115 |
+
|
116 |
+
self.layers = nn.ModuleList([])
|
117 |
+
for _ in range(depth):
|
118 |
+
self.layers.append(
|
119 |
+
nn.ModuleList(
|
120 |
+
[
|
121 |
+
PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
|
122 |
+
FeedForward(dim=dim, mult=ff_mult),
|
123 |
+
]
|
124 |
+
)
|
125 |
+
)
|
126 |
+
|
127 |
+
def forward(self, x):
|
128 |
+
if self.pos_emb is not None:
|
129 |
+
n, device = x.shape[1], x.device
|
130 |
+
pos_emb = self.pos_emb(torch.arange(n, device=device))
|
131 |
+
x = x + pos_emb
|
132 |
+
|
133 |
+
latents = self.latents.repeat(x.size(0), 1, 1)
|
134 |
+
|
135 |
+
x = self.proj_in(x)
|
136 |
+
|
137 |
+
if self.to_latents_from_mean_pooled_seq:
|
138 |
+
meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool))
|
139 |
+
meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
|
140 |
+
latents = torch.cat((meanpooled_latents, latents), dim=-2)
|
141 |
+
|
142 |
+
for attn, ff in self.layers:
|
143 |
+
latents = attn(x, latents) + latents
|
144 |
+
latents = ff(latents) + latents
|
145 |
+
|
146 |
+
latents = self.proj_out(latents)
|
147 |
+
return self.norm_out(latents)
|
148 |
+
|
149 |
+
|
150 |
+
def masked_mean(t, *, dim, mask=None):
|
151 |
+
if mask is None:
|
152 |
+
return t.mean(dim=dim)
|
153 |
+
|
154 |
+
denom = mask.sum(dim=dim, keepdim=True)
|
155 |
+
mask = rearrange(mask, "b n -> b n 1")
|
156 |
+
masked_t = t.masked_fill(~mask, 0.0)
|
157 |
+
|
158 |
+
return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)
|