Spaces:
Runtime error
Runtime error
Manikandan97
commited on
Commit
•
ff271bf
1
Parent(s):
0c9d86f
Pluid Updated
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -0
- LICENSE +201 -0
- README.md +102 -10
- app.py +224 -0
- app_flux.py +326 -0
- docs/pulid_for_flux.md +81 -0
- docs/v1.1_preview.md +14 -0
- eva_clip/__init__.py +11 -0
- eva_clip/bpe_simple_vocab_16e6.txt.gz +3 -0
- eva_clip/constants.py +2 -0
- eva_clip/eva_vit_model.py +548 -0
- eva_clip/factory.py +517 -0
- eva_clip/hf_configs.py +57 -0
- eva_clip/hf_model.py +248 -0
- eva_clip/loss.py +138 -0
- eva_clip/model.py +439 -0
- eva_clip/model_configs/EVA01-CLIP-B-16.json +19 -0
- eva_clip/model_configs/EVA01-CLIP-g-14-plus.json +24 -0
- eva_clip/model_configs/EVA01-CLIP-g-14.json +24 -0
- eva_clip/model_configs/EVA02-CLIP-B-16.json +29 -0
- eva_clip/model_configs/EVA02-CLIP-L-14-336.json +29 -0
- eva_clip/model_configs/EVA02-CLIP-L-14.json +29 -0
- eva_clip/model_configs/EVA02-CLIP-bigE-14-plus.json +25 -0
- eva_clip/model_configs/EVA02-CLIP-bigE-14.json +25 -0
- eva_clip/modified_resnet.py +181 -0
- eva_clip/openai.py +144 -0
- eva_clip/pretrained.py +332 -0
- eva_clip/rope.py +137 -0
- eva_clip/timm_model.py +122 -0
- eva_clip/tokenizer.py +201 -0
- eva_clip/transform.py +103 -0
- eva_clip/transformer.py +737 -0
- eva_clip/utils.py +326 -0
- example_inputs/hinton.jpeg +0 -0
- example_inputs/lecun.jpg +0 -0
- example_inputs/lifeifei.jpg +0 -0
- example_inputs/liuyifei.png +0 -0
- example_inputs/pengwei.jpg +3 -0
- example_inputs/rihanna.webp +0 -0
- example_inputs/zcy.webp +0 -0
- flux/__init__.py +11 -0
- flux/math.py +31 -0
- flux/model.py +157 -0
- flux/modules/__init__.py +0 -0
- flux/modules/autoencoder.py +312 -0
- flux/modules/conditioner.py +37 -0
- flux/modules/layers.py +253 -0
- flux/sampling.py +164 -0
- flux/util.py +191 -0
- pulid/attention_processor.py +422 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
example_inputs/pengwei.jpg filter=lfs diff=lfs merge=lfs -text
|
LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [yyyy] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
README.md
CHANGED
@@ -1,10 +1,102 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# PuLID
|
2 |
+
|
3 |
+
### :open_book: PuLID: Pure and Lightning ID Customization via Contrastive Alignment
|
4 |
+
> [![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2404.16022) [![xl](https://img.shields.io/badge/🤗-HuggingFaceDemo-orange)](https://huggingface.co/spaces/yanze/PuLID) [![flux](https://img.shields.io/badge/🤗-PuLID_FLUX_demo-orange)](https://huggingface.co/spaces/yanze/PuLID-FLUX) <br>
|
5 |
+
> Zinan Guo*, Yanze Wu*✝, Zhuowei Chen, Lang Chen, Qian He <br>
|
6 |
+
> (*Equal Contribution, ✝Corresponding Author) <br>
|
7 |
+
> ByteDance Inc <br>
|
8 |
+
|
9 |
+
### :triangular_flag_on_post: Updates
|
10 |
+
* **2024.09.12**: 💥 We're thrilled to announce the release of the **PuLID-FLUX-v0.9.0 model**. Enjoy exploring its capabilities! 😊 [Learn more about this model](docs/pulid_for_flux.md)
|
11 |
+
* **2024.05.23**: share the [preview of our upcoming v1.1 model](docs/v1.1_preview.md), please stay tuned
|
12 |
+
* **2024.05.01**: release v1 codes&models, also the [🤗HuggingFace Demo](https://huggingface.co/spaces/yanze/PuLID)
|
13 |
+
* **2024.04.25**: release arXiv paper.
|
14 |
+
|
15 |
+
## PuLID for FLUX
|
16 |
+
Please check the doc and demo of PuLID-FLUX [here](docs/pulid_for_flux.md).
|
17 |
+
|
18 |
+
We will actively update and maintain this repository in the near future, so please stay tuned.
|
19 |
+
|
20 |
+
### updates
|
21 |
+
- [x] Local gradio demo is ready now
|
22 |
+
- [x] Online HuggingFace demo is ready now [![flux](https://img.shields.io/badge/🤗-PuLID_FLUX_demo-orange)](https://huggingface.co/spaces/yanze/PuLID-FLUX)
|
23 |
+
- [x] We have optimized the codes to support consumer-grade GPUS, and now **PuLID-FLUX can run on a 16GB graphic card**. Check the details [here](https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md#local-gradio-demo)
|
24 |
+
|
25 |
+
|
26 |
+
Below results are generated with PuLID-FLUX.
|
27 |
+
![pulid_flux_results](https://github.com/user-attachments/assets/7eafb90a-fdd1-4ae7-bc41-8c428d568848)
|
28 |
+
|
29 |
+
|
30 |
+
## Examples
|
31 |
+
Images generated with our PuLID
|
32 |
+
![examples](https://github.com/ToTheBeginning/PuLID/assets/11482921/65610b0d-ba4f-4dc3-a74d-bd60f8f5ce37)
|
33 |
+
Applications
|
34 |
+
|
35 |
+
https://github.com/ToTheBeginning/PuLID/assets/11482921/9bdd0c8a-99e8-4eab-ab9e-39bf796cc6b8
|
36 |
+
|
37 |
+
## :wrench: Dependencies and Installation
|
38 |
+
- Python >= 3.9 (Recommend to use [Anaconda](https://www.anaconda.com/download/#linux) or [Miniconda](https://docs.conda.io/en/latest/miniconda.html))
|
39 |
+
- [PyTorch >= 2.0](https://pytorch.org/) if you don't need flux-dev-fp8, otherwise [PyTorch >= 2.4.1](https://pytorch.org/)
|
40 |
+
```bash
|
41 |
+
# clone PuLID repo
|
42 |
+
git clone https://github.com/ToTheBeginning/PuLID.git
|
43 |
+
cd PuLID
|
44 |
+
# create conda env
|
45 |
+
conda create --name pulid python=3.10
|
46 |
+
# activate env
|
47 |
+
conda activate pulid
|
48 |
+
# Install dependent packages
|
49 |
+
# 1. if you don't need flux-fp8, e.g., you are using xl or flux-bf16, install the following requirements.txt
|
50 |
+
pip install -r requirements.txt
|
51 |
+
# 2. if you need flux-fp8 (to put flux on consumer-grade gpu), install the following requirements_fp8.txt
|
52 |
+
pip install -r requirements_fp8.txt
|
53 |
+
```
|
54 |
+
|
55 |
+
## :zap: Quick Inference
|
56 |
+
### Local Gradio Demo
|
57 |
+
```bash
|
58 |
+
python app.py
|
59 |
+
```
|
60 |
+
|
61 |
+
### Online HuggingFace Demo
|
62 |
+
Thanks for the GPU grant from HuggingFace team, you can try PuLID HF demo in
|
63 |
+
[https://huggingface.co/spaces/yanze/PuLID](https://huggingface.co/spaces/yanze/PuLID)
|
64 |
+
|
65 |
+
## :paperclip: Related Resources
|
66 |
+
Following are some third-party implementations of PuLID we have found in the Internet.
|
67 |
+
We appreciate the efforts of the respective developers for making PuLID accessible to a wider audience.
|
68 |
+
If there are any PuLID based resources and applications that we have not mentioned here, please let us know,
|
69 |
+
and we will include them in this list.
|
70 |
+
|
71 |
+
#### Online Demo
|
72 |
+
- **Colab**: https://github.com/camenduru/PuLID-jupyter provided by [camenduru](https://github.com/camenduru)
|
73 |
+
- **Replicate**: https://replicate.com/zsxkib/pulid provided by [zsxkib](https://replicate.com/zsxkib)
|
74 |
+
|
75 |
+
#### ComfyUI
|
76 |
+
- https://github.com/cubiq/PuLID_ComfyUI provided by [cubiq](https://github.com/cubiq), native ComfyUI implementation
|
77 |
+
- https://github.com/ZHO-ZHO-ZHO/ComfyUI-PuLID-ZHO provided by [ZHO](https://github.com/ZHO-ZHO-ZHO), diffusers-based implementation
|
78 |
+
|
79 |
+
#### WebUI
|
80 |
+
- https://github.com/Mikubill/sd-webui-controlnet/pull/2838 provided by [huchenlei](https://github.com/huchenlei)
|
81 |
+
|
82 |
+
## Disclaimer
|
83 |
+
This project strives to impact the domain of AI-driven image generation positively. Users are granted the freedom to
|
84 |
+
create images using this tool, but they are expected to comply with local laws and utilize it responsibly.
|
85 |
+
The developers do not assume any responsibility for potential misuse by users.
|
86 |
+
|
87 |
+
|
88 |
+
## Citation
|
89 |
+
If PuLID is helpful, please help to ⭐ the repo.
|
90 |
+
|
91 |
+
If you find this project useful for your research, please consider citing our paper:
|
92 |
+
```bibtex
|
93 |
+
@article{guo2024pulid,
|
94 |
+
title={PuLID: Pure and Lightning ID Customization via Contrastive Alignment},
|
95 |
+
author={Guo, Zinan and Wu, Yanze and Chen, Zhuowei and Chen, Lang and He, Qian},
|
96 |
+
journal={arXiv preprint arXiv:2404.16022},
|
97 |
+
year={2024}
|
98 |
+
}
|
99 |
+
```
|
100 |
+
|
101 |
+
## :e-mail: Contact
|
102 |
+
If you have any comments or questions, please [open a new issue](https://github.com/ToTheBeginning/PuLID/issues/new/choose) or feel free to contact [Yanze Wu](https://tothebeginning.github.io/) and [Zinan Guo](mailto:guozinan.1@bytedance.com).
|
app.py
ADDED
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
|
5 |
+
from pulid import attention_processor as attention
|
6 |
+
from pulid.pipeline import PuLIDPipeline
|
7 |
+
from pulid.utils import resize_numpy_image_long, seed_everything
|
8 |
+
|
9 |
+
torch.set_grad_enabled(False)
|
10 |
+
|
11 |
+
pipeline = PuLIDPipeline()
|
12 |
+
|
13 |
+
# other params
|
14 |
+
DEFAULT_NEGATIVE_PROMPT = (
|
15 |
+
'flaws in the eyes, flaws in the face, flaws, lowres, non-HDRi, low quality, worst quality,'
|
16 |
+
'artifacts noise, text, watermark, glitch, deformed, mutated, ugly, disfigured, hands, '
|
17 |
+
'low resolution, partially rendered objects, deformed or partially rendered eyes, '
|
18 |
+
'deformed, deformed eyeballs, cross-eyed,blurry'
|
19 |
+
)
|
20 |
+
|
21 |
+
|
22 |
+
def run(*args):
|
23 |
+
id_image = args[0]
|
24 |
+
supp_images = args[1:4]
|
25 |
+
prompt, neg_prompt, scale, n_samples, seed, steps, H, W, id_scale, mode, id_mix = args[4:]
|
26 |
+
|
27 |
+
pipeline.debug_img_list = []
|
28 |
+
if mode == 'fidelity':
|
29 |
+
attention.NUM_ZERO = 8
|
30 |
+
attention.ORTHO = False
|
31 |
+
attention.ORTHO_v2 = True
|
32 |
+
elif mode == 'extremely style':
|
33 |
+
attention.NUM_ZERO = 16
|
34 |
+
attention.ORTHO = True
|
35 |
+
attention.ORTHO_v2 = False
|
36 |
+
else:
|
37 |
+
raise ValueError
|
38 |
+
|
39 |
+
if id_image is not None:
|
40 |
+
id_image = resize_numpy_image_long(id_image, 1024)
|
41 |
+
id_embeddings = pipeline.get_id_embedding(id_image)
|
42 |
+
for supp_id_image in supp_images:
|
43 |
+
if supp_id_image is not None:
|
44 |
+
supp_id_image = resize_numpy_image_long(supp_id_image, 1024)
|
45 |
+
supp_id_embeddings = pipeline.get_id_embedding(supp_id_image)
|
46 |
+
id_embeddings = torch.cat(
|
47 |
+
(id_embeddings, supp_id_embeddings if id_mix else supp_id_embeddings[:, :5]), dim=1
|
48 |
+
)
|
49 |
+
else:
|
50 |
+
id_embeddings = None
|
51 |
+
|
52 |
+
seed_everything(seed)
|
53 |
+
ims = []
|
54 |
+
for _ in range(n_samples):
|
55 |
+
img = pipeline.inference(prompt, (1, H, W), neg_prompt, id_embeddings, id_scale, scale, steps)[0]
|
56 |
+
ims.append(np.array(img))
|
57 |
+
|
58 |
+
return ims, pipeline.debug_img_list
|
59 |
+
|
60 |
+
|
61 |
+
_HEADER_ = '''
|
62 |
+
<h2><b>Official Gradio Demo</b></h2><h2><a href='https://github.com/ToTheBeginning/PuLID' target='_blank'><b>PuLID: Pure and Lightning ID Customization via Contrastive Alignment</b></a></h2>
|
63 |
+
|
64 |
+
**PuLID** is a tuning-free ID customization approach. PuLID maintains high ID fidelity while effectively reducing interference with the original model’s behavior.
|
65 |
+
|
66 |
+
Code: <a href='https://github.com/ToTheBeginning/PuLID' target='_blank'>GitHub</a>. Techenical report: <a href='https://arxiv.org/abs/2404.16022' target='_blank'>ArXiv</a>.
|
67 |
+
|
68 |
+
❗️❗️❗️**Tips:**
|
69 |
+
- we provide some examples in the bottom, you can try these example prompts first
|
70 |
+
- a single ID image is usually sufficient, you can also supplement with additional auxiliary images
|
71 |
+
- We offer two modes: fidelity mode and extremely style mode. In most cases, the default fidelity mode should suffice. If you find that the generated results are not stylized enough, you can choose the extremely style mode.
|
72 |
+
|
73 |
+
''' # noqa E501
|
74 |
+
|
75 |
+
_CITE_ = r"""
|
76 |
+
If PuLID is helpful, please help to ⭐ the <a href='https://github.com/ToTheBeginning/PuLID' target='_blank'>Github Repo</a>. Thanks! [![GitHub Stars](https://img.shields.io/github/stars/ToTheBeginning/PuLID?style=social)](https://github.com/ToTheBeginning/PuLID)
|
77 |
+
---
|
78 |
+
🚀 **Share**
|
79 |
+
If you have generated satisfying or interesting images with PuLID, please share them with us or your friends!
|
80 |
+
|
81 |
+
📝 **Citation**
|
82 |
+
If you find our work useful for your research or applications, please cite using this bibtex:
|
83 |
+
```bibtex
|
84 |
+
@article{guo2024pulid,
|
85 |
+
title={PuLID: Pure and Lightning ID Customization via Contrastive Alignment},
|
86 |
+
author={Guo, Zinan and Wu, Yanze and Chen, Zhuowei and Chen, Lang and He, Qian},
|
87 |
+
journal={arXiv preprint arXiv:2404.16022},
|
88 |
+
year={2024}
|
89 |
+
}
|
90 |
+
```
|
91 |
+
|
92 |
+
📋 **License**
|
93 |
+
Apache-2.0 LICENSE. Please refer to the [LICENSE file](placeholder) for details.
|
94 |
+
|
95 |
+
📧 **Contact**
|
96 |
+
If you have any questions, feel free to open a discussion or contact us at <b>wuyanze123@gmail.com</b> or <b>guozinan.1@bytedance.com</b>.
|
97 |
+
""" # noqa E501
|
98 |
+
|
99 |
+
|
100 |
+
with gr.Blocks(title="PuLID", css=".gr-box {border-color: #8136e2}") as demo:
|
101 |
+
gr.Markdown(_HEADER_)
|
102 |
+
with gr.Row():
|
103 |
+
with gr.Column():
|
104 |
+
with gr.Row():
|
105 |
+
face_image = gr.Image(label="ID image (main)", sources="upload", type="numpy", height=256)
|
106 |
+
supp_image1 = gr.Image(
|
107 |
+
label="Additional ID image (auxiliary)", sources="upload", type="numpy", height=256
|
108 |
+
)
|
109 |
+
supp_image2 = gr.Image(
|
110 |
+
label="Additional ID image (auxiliary)", sources="upload", type="numpy", height=256
|
111 |
+
)
|
112 |
+
supp_image3 = gr.Image(
|
113 |
+
label="Additional ID image (auxiliary)", sources="upload", type="numpy", height=256
|
114 |
+
)
|
115 |
+
prompt = gr.Textbox(label="Prompt", value='portrait,color,cinematic,in garden,soft light,detailed face')
|
116 |
+
submit = gr.Button("Generate")
|
117 |
+
neg_prompt = gr.Textbox(label="Negative Prompt", value=DEFAULT_NEGATIVE_PROMPT)
|
118 |
+
scale = gr.Slider(
|
119 |
+
label="CFG, recommend value range [1, 1.5], 1 will be faster ",
|
120 |
+
value=1.2,
|
121 |
+
minimum=1,
|
122 |
+
maximum=1.5,
|
123 |
+
step=0.1,
|
124 |
+
)
|
125 |
+
n_samples = gr.Slider(label="Num samples", value=4, minimum=1, maximum=8, step=1)
|
126 |
+
seed = gr.Slider(
|
127 |
+
label="Seed", value=42, minimum=np.iinfo(np.uint32).min, maximum=np.iinfo(np.uint32).max, step=1
|
128 |
+
)
|
129 |
+
steps = gr.Slider(label="Steps", value=4, minimum=1, maximum=100, step=1)
|
130 |
+
with gr.Row():
|
131 |
+
H = gr.Slider(label="Height", value=1024, minimum=512, maximum=2024, step=64)
|
132 |
+
W = gr.Slider(label="Width", value=768, minimum=512, maximum=2024, step=64)
|
133 |
+
with gr.Row():
|
134 |
+
id_scale = gr.Slider(label="ID scale", minimum=0, maximum=5, step=0.05, value=0.8, interactive=True)
|
135 |
+
mode = gr.Dropdown(label="mode", choices=['fidelity', 'extremely style'], value='fidelity')
|
136 |
+
id_mix = gr.Checkbox(
|
137 |
+
label="ID Mix (if you want to mix two ID image, please turn this on, otherwise, turn this off)",
|
138 |
+
value=False,
|
139 |
+
)
|
140 |
+
|
141 |
+
gr.Markdown("## Examples")
|
142 |
+
example_inps = [
|
143 |
+
[
|
144 |
+
'portrait,cinematic,wolf ears,white hair',
|
145 |
+
'example_inputs/liuyifei.png',
|
146 |
+
'fidelity',
|
147 |
+
]
|
148 |
+
]
|
149 |
+
gr.Examples(examples=example_inps, inputs=[prompt, face_image, mode], label='realistic')
|
150 |
+
|
151 |
+
example_inps = [
|
152 |
+
[
|
153 |
+
'portrait, impressionist painting, loose brushwork, vibrant color, light and shadow play',
|
154 |
+
'example_inputs/zcy.webp',
|
155 |
+
'fidelity',
|
156 |
+
]
|
157 |
+
]
|
158 |
+
gr.Examples(examples=example_inps, inputs=[prompt, face_image, mode], label='painting style')
|
159 |
+
|
160 |
+
example_inps = [
|
161 |
+
[
|
162 |
+
'portrait, flat papercut style, silhouette, clean cuts, paper, sharp edges, minimalist,color block,man', # noqa E501
|
163 |
+
'example_inputs/lecun.jpg',
|
164 |
+
'fidelity',
|
165 |
+
]
|
166 |
+
]
|
167 |
+
gr.Examples(examples=example_inps, inputs=[prompt, face_image, mode], label='papercut style')
|
168 |
+
|
169 |
+
example_inps = [
|
170 |
+
[
|
171 |
+
'woman,cartoon,solo,Popmart Blind Box, Super Mario, 3d',
|
172 |
+
'example_inputs/rihanna.webp',
|
173 |
+
'fidelity',
|
174 |
+
]
|
175 |
+
]
|
176 |
+
gr.Examples(examples=example_inps, inputs=[prompt, face_image, mode], label='3d style')
|
177 |
+
|
178 |
+
example_inps = [
|
179 |
+
[
|
180 |
+
'portrait, the legend of zelda, anime',
|
181 |
+
'example_inputs/liuyifei.png',
|
182 |
+
'extremely style',
|
183 |
+
]
|
184 |
+
]
|
185 |
+
gr.Examples(examples=example_inps, inputs=[prompt, face_image, mode], label='anime style')
|
186 |
+
|
187 |
+
example_inps = [
|
188 |
+
[
|
189 |
+
'portrait, superman',
|
190 |
+
'example_inputs/lecun.jpg',
|
191 |
+
'example_inputs/lifeifei.jpg',
|
192 |
+
'fidelity',
|
193 |
+
True,
|
194 |
+
]
|
195 |
+
]
|
196 |
+
gr.Examples(examples=example_inps, inputs=[prompt, face_image, supp_image1, mode, id_mix], label='id mix')
|
197 |
+
|
198 |
+
with gr.Column():
|
199 |
+
output = gr.Gallery(label='Output', elem_id="gallery")
|
200 |
+
intermediate_output = gr.Gallery(label='DebugImage', elem_id="gallery", visible=False)
|
201 |
+
gr.Markdown(_CITE_)
|
202 |
+
|
203 |
+
inps = [
|
204 |
+
face_image,
|
205 |
+
supp_image1,
|
206 |
+
supp_image2,
|
207 |
+
supp_image3,
|
208 |
+
prompt,
|
209 |
+
neg_prompt,
|
210 |
+
scale,
|
211 |
+
n_samples,
|
212 |
+
seed,
|
213 |
+
steps,
|
214 |
+
H,
|
215 |
+
W,
|
216 |
+
id_scale,
|
217 |
+
mode,
|
218 |
+
id_mix,
|
219 |
+
]
|
220 |
+
submit.click(fn=run, inputs=inps, outputs=[output, intermediate_output])
|
221 |
+
|
222 |
+
|
223 |
+
demo.queue(max_size=3)
|
224 |
+
demo.launch(server_name='0.0.0.0')
|
app_flux.py
ADDED
@@ -0,0 +1,326 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
|
3 |
+
import gradio as gr
|
4 |
+
import torch
|
5 |
+
from einops import rearrange
|
6 |
+
from PIL import Image
|
7 |
+
|
8 |
+
from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
|
9 |
+
from flux.util import (
|
10 |
+
SamplingOptions,
|
11 |
+
load_ae,
|
12 |
+
load_clip,
|
13 |
+
load_flow_model,
|
14 |
+
load_flow_model_quintized,
|
15 |
+
load_t5,
|
16 |
+
)
|
17 |
+
from pulid.pipeline_flux import PuLIDPipeline
|
18 |
+
from pulid.utils import resize_numpy_image_long
|
19 |
+
|
20 |
+
|
21 |
+
def get_models(name: str, device: torch.device, offload: bool, fp8: bool):
|
22 |
+
t5 = load_t5(device, max_length=128)
|
23 |
+
clip = load_clip(device)
|
24 |
+
if fp8:
|
25 |
+
model = load_flow_model_quintized(name, device="cpu" if offload else device)
|
26 |
+
else:
|
27 |
+
model = load_flow_model(name, device="cpu" if offload else device)
|
28 |
+
model.eval()
|
29 |
+
ae = load_ae(name, device="cpu" if offload else device)
|
30 |
+
return model, ae, t5, clip
|
31 |
+
|
32 |
+
|
33 |
+
class FluxGenerator:
|
34 |
+
def __init__(self, model_name: str, device: str, offload: bool, aggressive_offload: bool, args):
|
35 |
+
self.device = torch.device(device)
|
36 |
+
self.offload = offload
|
37 |
+
self.aggressive_offload = aggressive_offload
|
38 |
+
self.model_name = model_name
|
39 |
+
self.model, self.ae, self.t5, self.clip = get_models(
|
40 |
+
model_name,
|
41 |
+
device=self.device,
|
42 |
+
offload=self.offload,
|
43 |
+
fp8=args.fp8,
|
44 |
+
)
|
45 |
+
self.pulid_model = PuLIDPipeline(self.model, device="cpu" if offload else device, weight_dtype=torch.bfloat16,
|
46 |
+
onnx_provider=args.onnx_provider)
|
47 |
+
if offload:
|
48 |
+
self.pulid_model.face_helper.face_det.mean_tensor = self.pulid_model.face_helper.face_det.mean_tensor.to(torch.device("cuda"))
|
49 |
+
self.pulid_model.face_helper.face_det.device = torch.device("cuda")
|
50 |
+
self.pulid_model.face_helper.device = torch.device("cuda")
|
51 |
+
self.pulid_model.device = torch.device("cuda")
|
52 |
+
self.pulid_model.load_pretrain(args.pretrained_model)
|
53 |
+
|
54 |
+
@torch.inference_mode()
|
55 |
+
def generate_image(
|
56 |
+
self,
|
57 |
+
width,
|
58 |
+
height,
|
59 |
+
num_steps,
|
60 |
+
start_step,
|
61 |
+
guidance,
|
62 |
+
seed,
|
63 |
+
prompt,
|
64 |
+
id_image=None,
|
65 |
+
id_weight=1.0,
|
66 |
+
neg_prompt="",
|
67 |
+
true_cfg=1.0,
|
68 |
+
timestep_to_start_cfg=1,
|
69 |
+
max_sequence_length=128,
|
70 |
+
):
|
71 |
+
self.t5.max_length = max_sequence_length
|
72 |
+
|
73 |
+
seed = int(seed)
|
74 |
+
if seed == -1:
|
75 |
+
seed = None
|
76 |
+
|
77 |
+
opts = SamplingOptions(
|
78 |
+
prompt=prompt,
|
79 |
+
width=width,
|
80 |
+
height=height,
|
81 |
+
num_steps=num_steps,
|
82 |
+
guidance=guidance,
|
83 |
+
seed=seed,
|
84 |
+
)
|
85 |
+
|
86 |
+
if opts.seed is None:
|
87 |
+
opts.seed = torch.Generator(device="cpu").seed()
|
88 |
+
print(f"Generating '{opts.prompt}' with seed {opts.seed}")
|
89 |
+
t0 = time.perf_counter()
|
90 |
+
|
91 |
+
use_true_cfg = abs(true_cfg - 1.0) > 1e-2
|
92 |
+
|
93 |
+
# prepare input
|
94 |
+
x = get_noise(
|
95 |
+
1,
|
96 |
+
opts.height,
|
97 |
+
opts.width,
|
98 |
+
device=self.device,
|
99 |
+
dtype=torch.bfloat16,
|
100 |
+
seed=opts.seed,
|
101 |
+
)
|
102 |
+
timesteps = get_schedule(
|
103 |
+
opts.num_steps,
|
104 |
+
x.shape[-1] * x.shape[-2] // 4,
|
105 |
+
shift=True,
|
106 |
+
)
|
107 |
+
|
108 |
+
if self.offload:
|
109 |
+
self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device)
|
110 |
+
inp = prepare(t5=self.t5, clip=self.clip, img=x, prompt=opts.prompt)
|
111 |
+
inp_neg = prepare(t5=self.t5, clip=self.clip, img=x, prompt=neg_prompt) if use_true_cfg else None
|
112 |
+
|
113 |
+
# offload TEs to CPU, load processor models and id encoder to gpu
|
114 |
+
if self.offload:
|
115 |
+
self.t5, self.clip = self.t5.cpu(), self.clip.cpu()
|
116 |
+
torch.cuda.empty_cache()
|
117 |
+
self.pulid_model.components_to_device(torch.device("cuda"))
|
118 |
+
|
119 |
+
if id_image is not None:
|
120 |
+
id_image = resize_numpy_image_long(id_image, 1024)
|
121 |
+
id_embeddings, uncond_id_embeddings = self.pulid_model.get_id_embedding(id_image, cal_uncond=use_true_cfg)
|
122 |
+
else:
|
123 |
+
id_embeddings = None
|
124 |
+
uncond_id_embeddings = None
|
125 |
+
|
126 |
+
# offload processor models and id encoder to CPU, load dit model to gpu
|
127 |
+
if self.offload:
|
128 |
+
self.pulid_model.components_to_device(torch.device("cpu"))
|
129 |
+
torch.cuda.empty_cache()
|
130 |
+
if self.aggressive_offload:
|
131 |
+
self.model.components_to_gpu()
|
132 |
+
else:
|
133 |
+
self.model = self.model.to(self.device)
|
134 |
+
|
135 |
+
# denoise initial noise
|
136 |
+
x = denoise(
|
137 |
+
self.model, **inp, timesteps=timesteps, guidance=opts.guidance, id=id_embeddings, id_weight=id_weight,
|
138 |
+
start_step=start_step, uncond_id=uncond_id_embeddings, true_cfg=true_cfg,
|
139 |
+
timestep_to_start_cfg=timestep_to_start_cfg,
|
140 |
+
neg_txt=inp_neg["txt"] if use_true_cfg else None,
|
141 |
+
neg_txt_ids=inp_neg["txt_ids"] if use_true_cfg else None,
|
142 |
+
neg_vec=inp_neg["vec"] if use_true_cfg else None,
|
143 |
+
aggressive_offload=self.aggressive_offload,
|
144 |
+
)
|
145 |
+
|
146 |
+
# offload model, load autoencoder to gpu
|
147 |
+
if self.offload:
|
148 |
+
self.model.cpu()
|
149 |
+
torch.cuda.empty_cache()
|
150 |
+
self.ae.decoder.to(x.device)
|
151 |
+
|
152 |
+
# decode latents to pixel space
|
153 |
+
x = unpack(x.float(), opts.height, opts.width)
|
154 |
+
with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16):
|
155 |
+
x = self.ae.decode(x)
|
156 |
+
|
157 |
+
if self.offload:
|
158 |
+
self.ae.decoder.cpu()
|
159 |
+
torch.cuda.empty_cache()
|
160 |
+
|
161 |
+
t1 = time.perf_counter()
|
162 |
+
|
163 |
+
print(f"Done in {t1 - t0:.1f}s.")
|
164 |
+
# bring into PIL format
|
165 |
+
x = x.clamp(-1, 1)
|
166 |
+
# x = embed_watermark(x.float())
|
167 |
+
x = rearrange(x[0], "c h w -> h w c")
|
168 |
+
|
169 |
+
img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
|
170 |
+
return img, str(opts.seed), self.pulid_model.debug_img_list
|
171 |
+
|
172 |
+
_HEADER_ = '''
|
173 |
+
<div style="text-align: center; max-width: 650px; margin: 0 auto;">
|
174 |
+
<h1 style="font-size: 2.5rem; font-weight: 700; margin-bottom: 1rem; display: contents;">PuLID for FLUX</h1>
|
175 |
+
<p style="font-size: 1rem; margin-bottom: 1.5rem;">Paper: <a href='https://arxiv.org/abs/2404.16022' target='_blank'>PuLID: Pure and Lightning ID Customization via Contrastive Alignment</a> | Codes: <a href='https://github.com/ToTheBeginning/PuLID' target='_blank'>GitHub</a></p>
|
176 |
+
</div>
|
177 |
+
|
178 |
+
❗️❗️❗️**Tips:**
|
179 |
+
- `timestep to start inserting ID:` The smaller the value, the higher the fidelity, but the lower the editability; the higher the value, the lower the fidelity, but the higher the editability. **The recommended range for this value is between 0 and 4**. For photorealistic scenes, we recommend using 4; for stylized scenes, we recommend using 0-1. If you are not satisfied with the similarity, you can lower this value; conversely, if you are not satisfied with the editability, you can increase this value.
|
180 |
+
- `true CFG scale:` In most scenarios, it is recommended to use a fake CFG, i.e., setting the true CFG scale to 1, and just adjusting the guidance scale. This is also more efficiency. However, in a few cases, utilizing a true CFG can yield better results. For more detaileds, please refer to the [doc](https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md#useful-tips).
|
181 |
+
- please refer to the <a href='https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md' target='_blank'>github doc</a> for more details and info about the model, we provide the detail explanation about the above two parameters in the doc.
|
182 |
+
- we provide some examples in the bottom, you can try these example prompts first
|
183 |
+
|
184 |
+
''' # noqa E501
|
185 |
+
|
186 |
+
_CITE_ = r"""
|
187 |
+
If PuLID is helpful, please help to ⭐ the <a href='https://github.com/ToTheBeginning/PuLID' target='_blank'> Github Repo</a>. Thanks!
|
188 |
+
---
|
189 |
+
|
190 |
+
📧 **Contact**
|
191 |
+
If you have any questions or feedbacks, feel free to open a discussion or contact <b>wuyanze123@gmail.com</b>.
|
192 |
+
""" # noqa E501
|
193 |
+
|
194 |
+
|
195 |
+
def create_demo(args, model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu",
|
196 |
+
offload: bool = False, aggressive_offload: bool = False):
|
197 |
+
generator = FluxGenerator(model_name, device, offload, aggressive_offload, args)
|
198 |
+
|
199 |
+
with gr.Blocks() as demo:
|
200 |
+
gr.Markdown(_HEADER_)
|
201 |
+
|
202 |
+
with gr.Row():
|
203 |
+
with gr.Column():
|
204 |
+
prompt = gr.Textbox(label="Prompt", value="portrait, color, cinematic")
|
205 |
+
id_image = gr.Image(label="ID Image")
|
206 |
+
id_weight = gr.Slider(0.0, 3.0, 1, step=0.05, label="id weight")
|
207 |
+
|
208 |
+
width = gr.Slider(256, 1536, 896, step=16, label="Width")
|
209 |
+
height = gr.Slider(256, 1536, 1152, step=16, label="Height")
|
210 |
+
num_steps = gr.Slider(1, 20, 20, step=1, label="Number of steps")
|
211 |
+
start_step = gr.Slider(0, 10, 0, step=1, label="timestep to start inserting ID")
|
212 |
+
guidance = gr.Slider(1.0, 10.0, 4, step=0.1, label="Guidance")
|
213 |
+
seed = gr.Textbox(-1, label="Seed (-1 for random)")
|
214 |
+
max_sequence_length = gr.Slider(128, 512, 128, step=128,
|
215 |
+
label="max_sequence_length for prompt (T5), small will be faster")
|
216 |
+
|
217 |
+
with gr.Accordion("Advanced Options (True CFG, true_cfg_scale=1 means use fake CFG, >1 means use true CFG, if using true CFG, we recommend set the guidance scale to 1)", open=False): # noqa E501
|
218 |
+
neg_prompt = gr.Textbox(
|
219 |
+
label="Negative Prompt",
|
220 |
+
value="bad quality, worst quality, text, signature, watermark, extra limbs")
|
221 |
+
true_cfg = gr.Slider(1.0, 10.0, 1, step=0.1, label="true CFG scale")
|
222 |
+
timestep_to_start_cfg = gr.Slider(0, 20, 1, step=1, label="timestep to start cfg", visible=args.dev)
|
223 |
+
|
224 |
+
generate_btn = gr.Button("Generate")
|
225 |
+
|
226 |
+
with gr.Column():
|
227 |
+
output_image = gr.Image(label="Generated Image")
|
228 |
+
seed_output = gr.Textbox(label="Used Seed")
|
229 |
+
intermediate_output = gr.Gallery(label='Output', elem_id="gallery", visible=args.dev)
|
230 |
+
gr.Markdown(_CITE_)
|
231 |
+
|
232 |
+
with gr.Row(), gr.Column():
|
233 |
+
gr.Markdown("## Examples")
|
234 |
+
example_inps = [
|
235 |
+
[
|
236 |
+
'a woman holding sign with glowing green text \"PuLID for FLUX\"',
|
237 |
+
'example_inputs/liuyifei.png',
|
238 |
+
4, 4, 2680261499100305976, 1
|
239 |
+
],
|
240 |
+
[
|
241 |
+
'portrait, side view',
|
242 |
+
'example_inputs/liuyifei.png',
|
243 |
+
4, 4, 1205240166692517553, 1
|
244 |
+
],
|
245 |
+
[
|
246 |
+
'white-haired woman with vr technology atmosphere, revolutionary exceptional magnum with remarkable details', # noqa E501
|
247 |
+
'example_inputs/liuyifei.png',
|
248 |
+
4, 4, 6349424134217931066, 1
|
249 |
+
],
|
250 |
+
[
|
251 |
+
'a young child is eating Icecream',
|
252 |
+
'example_inputs/liuyifei.png',
|
253 |
+
4, 4, 10606046113565776207, 1
|
254 |
+
],
|
255 |
+
[
|
256 |
+
'a man is holding a sign with text \"PuLID for FLUX\", winter, snowing, top of the mountain',
|
257 |
+
'example_inputs/pengwei.jpg',
|
258 |
+
4, 4, 2410129802683836089, 1
|
259 |
+
],
|
260 |
+
[
|
261 |
+
'portrait, candle light',
|
262 |
+
'example_inputs/pengwei.jpg',
|
263 |
+
4, 4, 17522759474323955700, 1
|
264 |
+
],
|
265 |
+
[
|
266 |
+
'profile shot dark photo of a 25-year-old male with smoke escaping from his mouth, the backlit smoke gives the image an ephemeral quality, natural face, natural eyebrows, natural skin texture, award winning photo, highly detailed face, atmospheric lighting, film grain, monochrome', # noqa E501
|
267 |
+
'example_inputs/pengwei.jpg',
|
268 |
+
4, 4, 17733156847328193625, 1
|
269 |
+
],
|
270 |
+
[
|
271 |
+
'American Comics, 1boy',
|
272 |
+
'example_inputs/pengwei.jpg',
|
273 |
+
1, 4, 13223174453874179686, 1
|
274 |
+
],
|
275 |
+
[
|
276 |
+
'portrait, pixar',
|
277 |
+
'example_inputs/pengwei.jpg',
|
278 |
+
1, 4, 9445036702517583939, 1
|
279 |
+
],
|
280 |
+
]
|
281 |
+
gr.Examples(examples=example_inps, inputs=[prompt, id_image, start_step, guidance, seed, true_cfg],
|
282 |
+
label='fake CFG')
|
283 |
+
|
284 |
+
example_inps = [
|
285 |
+
[
|
286 |
+
'portrait, made of ice sculpture',
|
287 |
+
'example_inputs/lecun.jpg',
|
288 |
+
1, 1, 3811899118709451814, 5
|
289 |
+
],
|
290 |
+
]
|
291 |
+
gr.Examples(examples=example_inps, inputs=[prompt, id_image, start_step, guidance, seed, true_cfg],
|
292 |
+
label='true CFG')
|
293 |
+
|
294 |
+
generate_btn.click(
|
295 |
+
fn=generator.generate_image,
|
296 |
+
inputs=[width, height, num_steps, start_step, guidance, seed, prompt, id_image, id_weight, neg_prompt,
|
297 |
+
true_cfg, timestep_to_start_cfg, max_sequence_length],
|
298 |
+
outputs=[output_image, seed_output, intermediate_output],
|
299 |
+
)
|
300 |
+
|
301 |
+
return demo
|
302 |
+
|
303 |
+
|
304 |
+
if __name__ == "__main__":
|
305 |
+
import argparse
|
306 |
+
|
307 |
+
parser = argparse.ArgumentParser(description="PuLID for FLUX.1-dev")
|
308 |
+
parser.add_argument("--name", type=str, default="flux-dev", choices=list('flux-dev'),
|
309 |
+
help="currently only support flux-dev")
|
310 |
+
parser.add_argument("--device", type=str, default="cuda", help="Device to use")
|
311 |
+
parser.add_argument("--offload", action="store_true", help="Offload model to CPU when not in use")
|
312 |
+
parser.add_argument("--aggressive_offload", action="store_true", help="Offload model more aggressively to CPU when not in use, for 24G GPUs")
|
313 |
+
parser.add_argument("--fp8", action="store_true", help="use flux-dev-fp8 model")
|
314 |
+
parser.add_argument("--onnx_provider", type=str, default="gpu", choices=["gpu", "cpu"],
|
315 |
+
help="set onnx_provider to cpu (default gpu) can help reduce RAM usage, and when combined with"
|
316 |
+
"fp8 option, the peak RAM is under 15GB")
|
317 |
+
parser.add_argument("--port", type=int, default=8080, help="Port to use")
|
318 |
+
parser.add_argument("--dev", action='store_true', help="Development mode")
|
319 |
+
parser.add_argument("--pretrained_model", type=str, help='for development')
|
320 |
+
args = parser.parse_args()
|
321 |
+
|
322 |
+
if args.aggressive_offload:
|
323 |
+
args.offload = True
|
324 |
+
|
325 |
+
demo = create_demo(args, args.name, args.device, args.offload, args.aggressive_offload)
|
326 |
+
demo.launch(server_name='0.0.0.0', server_port=args.port)
|
docs/pulid_for_flux.md
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# PuLID for FLUX
|
2 |
+
We are happy to release the **PuLID-FLUX-v0.9.0** model, which provides a tuning-free ID customization solution for FLUX.1-dev.
|
3 |
+
|
4 |
+
If PuLID-FLUX is helpful, please help to ⭐ this repo or recommend it to your friends 😊
|
5 |
+
|
6 |
+
## Inference
|
7 |
+
### Local Gradio Demo
|
8 |
+
You first need to follow the [dependencies-and-installation](../README.md#wrench-dependencies-and-installation) to set
|
9 |
+
up the environment, and download the `flux1-dev.safetensors` (if you want to use bf16 rather than fp8) and `ae.safetensors` from [black-forest-labs/FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev/tree/main).
|
10 |
+
The PuLID-FLUX model will be automatically downloaded from [huggingface](https://huggingface.co/guozinan/PuLID/tree/main).
|
11 |
+
|
12 |
+
There are following four options to run the gradio demo:
|
13 |
+
|
14 |
+
#### naive bf16
|
15 |
+
simply run `python app_flux.py`, the peak memory is under 45GB.
|
16 |
+
|
17 |
+
#### bf16 + offload
|
18 |
+
run `python app_flux.py --offload`, the peak memory is under 30GB.
|
19 |
+
|
20 |
+
#### fp8 + offload (for consumer-grade GPUs)
|
21 |
+
To use fp8, you need to make sure you have installed `requirements-fp8.txt`, it includes `optimum-quanto` and higher version of PyTorch.
|
22 |
+
We use `flux-dev-fp8` checkpoint from [XLabs-AI/flux-dev-fp8](https://huggingface.co/XLabs-AI/flux-dev-fp8), it will be automatically downloaded. You can also download it manually and put it in the models folder
|
23 |
+
|
24 |
+
Run `python app_flux.py --offload --fp8 --onnx_provider cpu`, the peak memory is under 15GB, this is for GPU with 16GB memory.
|
25 |
+
|
26 |
+
For 24GB graphic memory users, you can run `python app_flux.py --offload --fp8`, the peak memory is under 17GB.
|
27 |
+
|
28 |
+
However, there is a difference in image quality between fp8 and bf16, with some degradation in the former.
|
29 |
+
Specifically, the details of the face may be slightly worse, but the layout is similar. If you want the best results
|
30 |
+
of PuLID-FLUX or you have the resources, please use bf16 rather than fp8.
|
31 |
+
We have included a comparison in the table below.
|
32 |
+
|
33 |
+
| | case1 | case2 | case3 | case4 |
|
34 |
+
|------|:-------------------------------------------------------------------------------------------:|:--------------------------------------------------------------------------------------------:|:-------------------------------------------------------------------------------------------:|:----------------------------------------------------------------------------------------:|
|
35 |
+
| bf16 | ![c1_bf16](https://github.com/user-attachments/assets/781b2102-d5fe-4786-b4d3-7b8df501c781) | ![c2_bf16](https://github.com/user-attachments/assets/6218a6ca-f07e-4a9a-ac63-896526ff52cf) | ![c3_bf16](https://github.com/user-attachments/assets/3b6675e5-d26e-4799-b0f3-72e4a7f9a771) |![c4_bf16](https://github.com/user-attachments/assets/b4e162ca-da8b-4e68-8d6b-ba1a674b2a0b)|
|
36 |
+
| fp8 | ![c1_fp8](https://github.com/user-attachments/assets/8547f020-bd39-4e9b-aa82-b85be4efc41c) | ![c2_fp8](https://github.com/user-attachments/assets/00d3d485-0298-4966-82e1-a31946797ac8) | ![c3_fp8](https://github.com/user-attachments/assets/b1c6a6b6-1140-49a3-93bd-1245ee5fef4c) |![c4_fp8](https://github.com/user-attachments/assets/62e512ca-6315-4a89-9350-430e20b86b36)|
|
37 |
+
|
38 |
+
|
39 |
+
#### bf16 + more agreesive offload
|
40 |
+
run `python app_flux.py --aggressive_offload`, the peak memory is around 23GB.
|
41 |
+
But it will be very, very slow. If you have better solution to run bf16 under 24GB, please let us know.
|
42 |
+
|
43 |
+
### Online Demo
|
44 |
+
- huggingface demo:
|
45 |
+
[https://huggingface.co/spaces/yanze/PuLID-FLUX](https://huggingface.co/spaces/yanze/PuLID-FLUX)
|
46 |
+
|
47 |
+
### ComfyUI
|
48 |
+
Please stay tuned for the community implementation
|
49 |
+
|
50 |
+
## Visual Results
|
51 |
+
![pulid_flux_results](https://github.com/user-attachments/assets/7eafb90a-fdd1-4ae7-bc41-8c428d568848)
|
52 |
+
|
53 |
+
|
54 |
+
## Useful Tips
|
55 |
+
There are two parameters that are crucial and need to be set carefully:
|
56 |
+
|
57 |
+
1. `timestep to start inserting ID`: This parameter controls the timing of ID insertion. If set to 0, the ID starts being inserted to the DIT from the first timestep. The earlier it is inserted, the higher the ID fidelity will be, but the editability may decrease. The later it is inserted, the lower the fidelity to the ID, but the editability will increase, and the disruption to the original model behavior will also be smaller. For generating realistic images, we suggest setting this to 4. If you found the ID similarity is not high enough, you could try lowering this parameter accordingly. For generating stylized images, we suggest setting it to 0-1.
|
58 |
+
![start_id](https://github.com/user-attachments/assets/3866ffab-542d-4e2f-9a0c-6877c9158d49)
|
59 |
+
|
60 |
+
2. `true CFG scale`: FLUX.1-dev is a guidance distill model. The original CFG process, which required twice the number of inference steps, is distilled into a guidance scale, thereby modulating the DIT through the guidance scale to simulate the true CFG process with half the inference steps. We will refer to this as fake CFG in the following doc. Our PuLID-FLUX model can be tested under the fake CFG settings, and the guidance scale can be set to a commonly used value, such as 4. However, the model also supports using the real CFG for inference. We compare the results of using true CFG with the fake CFG in photorealistic scenarios below.
|
61 |
+
![fake_cfg_vs_true_cfg_fidelity](https://github.com/user-attachments/assets/73b44dc8-37c7-48c8-8f55-73882731126d)
|
62 |
+
As shown in the above image, in terms of ID fidelity, using fake CFG is similar to true CFG in most cases, except that in a few cases, true CFG achieves higher ID similarity. In terms of image aesthetics and facial naturalness, fake CFG performs better. However, by carefully adjusting hyperparameters, the performance of true CFG may be further improved, we leave this to the community to explore. Therefore, we recommend using fake CFG for photorealistic scenes. If you are not satisfy about the ID fidelity, you can try switching to true CFG. Additionally, as shown below, we have found that using fake CFG in stylized scenes sometimes results in lower ID similarity and poorer style response, so if you encounter these two issues in stylized scenes, please consider switching to true CFG.
|
63 |
+
![fake_cfg_vs_true_cfg_style](https://github.com/user-attachments/assets/fb042639-64e6-4bb3-a3a4-5c138793318e)
|
64 |
+
|
65 |
+
|
66 |
+
|
67 |
+
## Some Technical Details
|
68 |
+
- We switch the ID encoder from an MLP structure to a Transformer structure. Interested users can refer to [source code](https://github.com/ToTheBeginning/PuLID/blob/cce7cdd65b5bf283c1a39c29f2726902a3c135ca/pulid/encoders_flux.py#L122)
|
69 |
+
- Inspired by [Flamingo](https://arxiv.org/abs/2204.14198), we insert additional cross-attention blocks every few DIT blocks to interact ID features with DIT image features
|
70 |
+
- We would like to clarify that the acceleration method (lile SDXL-Lightning) serves as an
|
71 |
+
optional acceleration trick, but it is not indispensable for training PuLID. We will update the arxiv paper with the relevant details in the near future. Please stay tuned.
|
72 |
+
|
73 |
+
|
74 |
+
## limitation
|
75 |
+
The model is currently in beta version, and we have observed that the ID fidelity may not be high for some male inputs, maybe the model requires more training. If the improved model is ready, we will release it here, so please stay tuned.
|
76 |
+
|
77 |
+
## License
|
78 |
+
As long as you use FLUX.1-dev model, you should follow the [FLUX.1-dev model license](https://github.com/black-forest-labs/flux/tree/main/model_licenses)
|
79 |
+
|
80 |
+
## contact
|
81 |
+
If you have any questions or suggestions about the model, please contact [Yanze Wu](https://tothebeginning.github.io/) or open an issue/discussion here.
|
docs/v1.1_preview.md
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# PuLID v1.1 preview
|
2 |
+
## The improvements of PuLID v1.1
|
3 |
+
|
4 |
+
In PuLID v1.1, we have made the following improvements:
|
5 |
+
- **better naturalness**
|
6 |
+
- **stronger editability**
|
7 |
+
- **more compatible with community models**
|
8 |
+
|
9 |
+
### PuLID with RealVis-XL as base model. Zoom in for best view
|
10 |
+
![realvis](https://github.com/ToTheBeginning/PuLID/assets/169147031/d6aa288b-b826-41bb-a512-96f9d54b448f)
|
11 |
+
### PuLID with Juggernaut-XL-Lightning as base model. Zoom in for best view
|
12 |
+
![juggernautXL_lightning](https://github.com/ToTheBeginning/PuLID/assets/169147031/4371d6b2-1063-49be-9ff1-56db58140cfe)
|
13 |
+
### PuLID with Dreamshaper-XL-Lightning as base model. Zoom in for best view
|
14 |
+
![dreamshaper](https://github.com/ToTheBeginning/PuLID/assets/169147031/89a21ee0-25c1-4098-a868-59e3149fe10c)
|
eva_clip/__init__.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
|
2 |
+
from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_transforms
|
3 |
+
from .factory import list_models, add_model_config, get_model_config, load_checkpoint
|
4 |
+
from .loss import ClipLoss
|
5 |
+
from .model import CLIP, CustomCLIP, CLIPTextCfg, CLIPVisionCfg,\
|
6 |
+
convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype
|
7 |
+
from .openai import load_openai_model, list_openai_models
|
8 |
+
from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model,\
|
9 |
+
get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
|
10 |
+
from .tokenizer import SimpleTokenizer, tokenize
|
11 |
+
from .transform import image_transform
|
eva_clip/bpe_simple_vocab_16e6.txt.gz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
|
3 |
+
size 1356917
|
eva_clip/constants.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
|
2 |
+
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
|
eva_clip/eva_vit_model.py
ADDED
@@ -0,0 +1,548 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --------------------------------------------------------
|
2 |
+
# Adapted from https://github.com/microsoft/unilm/tree/master/beit
|
3 |
+
# --------------------------------------------------------
|
4 |
+
import math
|
5 |
+
import os
|
6 |
+
from functools import partial
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
try:
|
11 |
+
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
|
12 |
+
except:
|
13 |
+
from timm.layers import drop_path, to_2tuple, trunc_normal_
|
14 |
+
|
15 |
+
from .transformer import PatchDropout
|
16 |
+
from .rope import VisionRotaryEmbedding, VisionRotaryEmbeddingFast
|
17 |
+
|
18 |
+
if os.getenv('ENV_TYPE') == 'deepspeed':
|
19 |
+
try:
|
20 |
+
from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
|
21 |
+
except:
|
22 |
+
from torch.utils.checkpoint import checkpoint
|
23 |
+
else:
|
24 |
+
from torch.utils.checkpoint import checkpoint
|
25 |
+
|
26 |
+
try:
|
27 |
+
import xformers
|
28 |
+
import xformers.ops as xops
|
29 |
+
XFORMERS_IS_AVAILBLE = True
|
30 |
+
except:
|
31 |
+
XFORMERS_IS_AVAILBLE = False
|
32 |
+
|
33 |
+
class DropPath(nn.Module):
|
34 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
35 |
+
"""
|
36 |
+
def __init__(self, drop_prob=None):
|
37 |
+
super(DropPath, self).__init__()
|
38 |
+
self.drop_prob = drop_prob
|
39 |
+
|
40 |
+
def forward(self, x):
|
41 |
+
return drop_path(x, self.drop_prob, self.training)
|
42 |
+
|
43 |
+
def extra_repr(self) -> str:
|
44 |
+
return 'p={}'.format(self.drop_prob)
|
45 |
+
|
46 |
+
|
47 |
+
class Mlp(nn.Module):
|
48 |
+
def __init__(
|
49 |
+
self,
|
50 |
+
in_features,
|
51 |
+
hidden_features=None,
|
52 |
+
out_features=None,
|
53 |
+
act_layer=nn.GELU,
|
54 |
+
norm_layer=nn.LayerNorm,
|
55 |
+
drop=0.,
|
56 |
+
subln=False,
|
57 |
+
|
58 |
+
):
|
59 |
+
super().__init__()
|
60 |
+
out_features = out_features or in_features
|
61 |
+
hidden_features = hidden_features or in_features
|
62 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
63 |
+
self.act = act_layer()
|
64 |
+
|
65 |
+
self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
|
66 |
+
|
67 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
68 |
+
self.drop = nn.Dropout(drop)
|
69 |
+
|
70 |
+
def forward(self, x):
|
71 |
+
x = self.fc1(x)
|
72 |
+
x = self.act(x)
|
73 |
+
# x = self.drop(x)
|
74 |
+
# commit this for the orignal BERT implement
|
75 |
+
x = self.ffn_ln(x)
|
76 |
+
|
77 |
+
x = self.fc2(x)
|
78 |
+
x = self.drop(x)
|
79 |
+
return x
|
80 |
+
|
81 |
+
class SwiGLU(nn.Module):
|
82 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.,
|
83 |
+
norm_layer=nn.LayerNorm, subln=False):
|
84 |
+
super().__init__()
|
85 |
+
out_features = out_features or in_features
|
86 |
+
hidden_features = hidden_features or in_features
|
87 |
+
|
88 |
+
self.w1 = nn.Linear(in_features, hidden_features)
|
89 |
+
self.w2 = nn.Linear(in_features, hidden_features)
|
90 |
+
|
91 |
+
self.act = act_layer()
|
92 |
+
self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
|
93 |
+
self.w3 = nn.Linear(hidden_features, out_features)
|
94 |
+
|
95 |
+
self.drop = nn.Dropout(drop)
|
96 |
+
|
97 |
+
def forward(self, x):
|
98 |
+
x1 = self.w1(x)
|
99 |
+
x2 = self.w2(x)
|
100 |
+
hidden = self.act(x1) * x2
|
101 |
+
x = self.ffn_ln(hidden)
|
102 |
+
x = self.w3(x)
|
103 |
+
x = self.drop(x)
|
104 |
+
return x
|
105 |
+
|
106 |
+
class Attention(nn.Module):
|
107 |
+
def __init__(
|
108 |
+
self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
|
109 |
+
proj_drop=0., window_size=None, attn_head_dim=None, xattn=False, rope=None, subln=False, norm_layer=nn.LayerNorm):
|
110 |
+
super().__init__()
|
111 |
+
self.num_heads = num_heads
|
112 |
+
head_dim = dim // num_heads
|
113 |
+
if attn_head_dim is not None:
|
114 |
+
head_dim = attn_head_dim
|
115 |
+
all_head_dim = head_dim * self.num_heads
|
116 |
+
self.scale = qk_scale or head_dim ** -0.5
|
117 |
+
|
118 |
+
self.subln = subln
|
119 |
+
if self.subln:
|
120 |
+
self.q_proj = nn.Linear(dim, all_head_dim, bias=False)
|
121 |
+
self.k_proj = nn.Linear(dim, all_head_dim, bias=False)
|
122 |
+
self.v_proj = nn.Linear(dim, all_head_dim, bias=False)
|
123 |
+
else:
|
124 |
+
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
|
125 |
+
|
126 |
+
if qkv_bias:
|
127 |
+
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
|
128 |
+
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
|
129 |
+
else:
|
130 |
+
self.q_bias = None
|
131 |
+
self.v_bias = None
|
132 |
+
|
133 |
+
if window_size:
|
134 |
+
self.window_size = window_size
|
135 |
+
self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
|
136 |
+
self.relative_position_bias_table = nn.Parameter(
|
137 |
+
torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
138 |
+
# cls to token & token 2 cls & cls to cls
|
139 |
+
|
140 |
+
# get pair-wise relative position index for each token inside the window
|
141 |
+
coords_h = torch.arange(window_size[0])
|
142 |
+
coords_w = torch.arange(window_size[1])
|
143 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
144 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
145 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
146 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
147 |
+
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
|
148 |
+
relative_coords[:, :, 1] += window_size[1] - 1
|
149 |
+
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
|
150 |
+
relative_position_index = \
|
151 |
+
torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)
|
152 |
+
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
153 |
+
relative_position_index[0, 0:] = self.num_relative_distance - 3
|
154 |
+
relative_position_index[0:, 0] = self.num_relative_distance - 2
|
155 |
+
relative_position_index[0, 0] = self.num_relative_distance - 1
|
156 |
+
|
157 |
+
self.register_buffer("relative_position_index", relative_position_index)
|
158 |
+
else:
|
159 |
+
self.window_size = None
|
160 |
+
self.relative_position_bias_table = None
|
161 |
+
self.relative_position_index = None
|
162 |
+
|
163 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
164 |
+
self.inner_attn_ln = norm_layer(all_head_dim) if subln else nn.Identity()
|
165 |
+
# self.proj = nn.Linear(all_head_dim, all_head_dim)
|
166 |
+
self.proj = nn.Linear(all_head_dim, dim)
|
167 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
168 |
+
self.xattn = xattn
|
169 |
+
self.xattn_drop = attn_drop
|
170 |
+
|
171 |
+
self.rope = rope
|
172 |
+
|
173 |
+
def forward(self, x, rel_pos_bias=None, attn_mask=None):
|
174 |
+
B, N, C = x.shape
|
175 |
+
if self.subln:
|
176 |
+
q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias)
|
177 |
+
k = F.linear(input=x, weight=self.k_proj.weight, bias=None)
|
178 |
+
v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias)
|
179 |
+
|
180 |
+
q = q.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) # B, num_heads, N, C
|
181 |
+
k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
|
182 |
+
v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
|
183 |
+
else:
|
184 |
+
|
185 |
+
qkv_bias = None
|
186 |
+
if self.q_bias is not None:
|
187 |
+
qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
|
188 |
+
|
189 |
+
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
|
190 |
+
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) # 3, B, num_heads, N, C
|
191 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
192 |
+
|
193 |
+
if self.rope:
|
194 |
+
# slightly fast impl
|
195 |
+
q_t = q[:, :, 1:, :]
|
196 |
+
ro_q_t = self.rope(q_t)
|
197 |
+
q = torch.cat((q[:, :, :1, :], ro_q_t), -2).type_as(v)
|
198 |
+
|
199 |
+
k_t = k[:, :, 1:, :]
|
200 |
+
ro_k_t = self.rope(k_t)
|
201 |
+
k = torch.cat((k[:, :, :1, :], ro_k_t), -2).type_as(v)
|
202 |
+
|
203 |
+
if self.xattn:
|
204 |
+
q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C
|
205 |
+
k = k.permute(0, 2, 1, 3)
|
206 |
+
v = v.permute(0, 2, 1, 3)
|
207 |
+
|
208 |
+
x = xops.memory_efficient_attention(
|
209 |
+
q, k, v,
|
210 |
+
p=self.xattn_drop,
|
211 |
+
scale=self.scale,
|
212 |
+
)
|
213 |
+
x = x.reshape(B, N, -1)
|
214 |
+
x = self.inner_attn_ln(x)
|
215 |
+
x = self.proj(x)
|
216 |
+
x = self.proj_drop(x)
|
217 |
+
else:
|
218 |
+
q = q * self.scale
|
219 |
+
attn = (q @ k.transpose(-2, -1))
|
220 |
+
|
221 |
+
if self.relative_position_bias_table is not None:
|
222 |
+
relative_position_bias = \
|
223 |
+
self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
224 |
+
self.window_size[0] * self.window_size[1] + 1,
|
225 |
+
self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
|
226 |
+
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
227 |
+
attn = attn + relative_position_bias.unsqueeze(0).type_as(attn)
|
228 |
+
|
229 |
+
if rel_pos_bias is not None:
|
230 |
+
attn = attn + rel_pos_bias.type_as(attn)
|
231 |
+
|
232 |
+
if attn_mask is not None:
|
233 |
+
attn_mask = attn_mask.bool()
|
234 |
+
attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf"))
|
235 |
+
|
236 |
+
attn = attn.softmax(dim=-1)
|
237 |
+
attn = self.attn_drop(attn)
|
238 |
+
|
239 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
|
240 |
+
x = self.inner_attn_ln(x)
|
241 |
+
x = self.proj(x)
|
242 |
+
x = self.proj_drop(x)
|
243 |
+
return x
|
244 |
+
|
245 |
+
|
246 |
+
class Block(nn.Module):
|
247 |
+
|
248 |
+
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
|
249 |
+
drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
|
250 |
+
window_size=None, attn_head_dim=None, xattn=False, rope=None, postnorm=False,
|
251 |
+
subln=False, naiveswiglu=False):
|
252 |
+
super().__init__()
|
253 |
+
self.norm1 = norm_layer(dim)
|
254 |
+
self.attn = Attention(
|
255 |
+
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
256 |
+
attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim,
|
257 |
+
xattn=xattn, rope=rope, subln=subln, norm_layer=norm_layer)
|
258 |
+
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
259 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
260 |
+
self.norm2 = norm_layer(dim)
|
261 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
262 |
+
|
263 |
+
if naiveswiglu:
|
264 |
+
self.mlp = SwiGLU(
|
265 |
+
in_features=dim,
|
266 |
+
hidden_features=mlp_hidden_dim,
|
267 |
+
subln=subln,
|
268 |
+
norm_layer=norm_layer,
|
269 |
+
)
|
270 |
+
else:
|
271 |
+
self.mlp = Mlp(
|
272 |
+
in_features=dim,
|
273 |
+
hidden_features=mlp_hidden_dim,
|
274 |
+
act_layer=act_layer,
|
275 |
+
subln=subln,
|
276 |
+
drop=drop
|
277 |
+
)
|
278 |
+
|
279 |
+
if init_values is not None and init_values > 0:
|
280 |
+
self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
|
281 |
+
self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
|
282 |
+
else:
|
283 |
+
self.gamma_1, self.gamma_2 = None, None
|
284 |
+
|
285 |
+
self.postnorm = postnorm
|
286 |
+
|
287 |
+
def forward(self, x, rel_pos_bias=None, attn_mask=None):
|
288 |
+
if self.gamma_1 is None:
|
289 |
+
if self.postnorm:
|
290 |
+
x = x + self.drop_path(self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)))
|
291 |
+
x = x + self.drop_path(self.norm2(self.mlp(x)))
|
292 |
+
else:
|
293 |
+
x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))
|
294 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
295 |
+
else:
|
296 |
+
if self.postnorm:
|
297 |
+
x = x + self.drop_path(self.gamma_1 * self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)))
|
298 |
+
x = x + self.drop_path(self.gamma_2 * self.norm2(self.mlp(x)))
|
299 |
+
else:
|
300 |
+
x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))
|
301 |
+
x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
|
302 |
+
return x
|
303 |
+
|
304 |
+
|
305 |
+
class PatchEmbed(nn.Module):
|
306 |
+
""" Image to Patch Embedding
|
307 |
+
"""
|
308 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
|
309 |
+
super().__init__()
|
310 |
+
img_size = to_2tuple(img_size)
|
311 |
+
patch_size = to_2tuple(patch_size)
|
312 |
+
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
|
313 |
+
self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
|
314 |
+
self.img_size = img_size
|
315 |
+
self.patch_size = patch_size
|
316 |
+
self.num_patches = num_patches
|
317 |
+
|
318 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
319 |
+
|
320 |
+
def forward(self, x, **kwargs):
|
321 |
+
B, C, H, W = x.shape
|
322 |
+
# FIXME look at relaxing size constraints
|
323 |
+
assert H == self.img_size[0] and W == self.img_size[1], \
|
324 |
+
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
325 |
+
x = self.proj(x).flatten(2).transpose(1, 2)
|
326 |
+
return x
|
327 |
+
|
328 |
+
|
329 |
+
class RelativePositionBias(nn.Module):
|
330 |
+
|
331 |
+
def __init__(self, window_size, num_heads):
|
332 |
+
super().__init__()
|
333 |
+
self.window_size = window_size
|
334 |
+
self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
|
335 |
+
self.relative_position_bias_table = nn.Parameter(
|
336 |
+
torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
337 |
+
# cls to token & token 2 cls & cls to cls
|
338 |
+
|
339 |
+
# get pair-wise relative position index for each token inside the window
|
340 |
+
coords_h = torch.arange(window_size[0])
|
341 |
+
coords_w = torch.arange(window_size[1])
|
342 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
343 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
344 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
345 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
346 |
+
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
|
347 |
+
relative_coords[:, :, 1] += window_size[1] - 1
|
348 |
+
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
|
349 |
+
relative_position_index = \
|
350 |
+
torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
|
351 |
+
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
352 |
+
relative_position_index[0, 0:] = self.num_relative_distance - 3
|
353 |
+
relative_position_index[0:, 0] = self.num_relative_distance - 2
|
354 |
+
relative_position_index[0, 0] = self.num_relative_distance - 1
|
355 |
+
|
356 |
+
self.register_buffer("relative_position_index", relative_position_index)
|
357 |
+
|
358 |
+
def forward(self):
|
359 |
+
relative_position_bias = \
|
360 |
+
self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
361 |
+
self.window_size[0] * self.window_size[1] + 1,
|
362 |
+
self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
|
363 |
+
return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
364 |
+
|
365 |
+
|
366 |
+
class EVAVisionTransformer(nn.Module):
|
367 |
+
""" Vision Transformer with support for patch or hybrid CNN input stage
|
368 |
+
"""
|
369 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
|
370 |
+
num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
|
371 |
+
drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, patch_dropout=0.,
|
372 |
+
use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False, rope=False,
|
373 |
+
use_mean_pooling=True, init_scale=0.001, grad_checkpointing=False, xattn=False, postnorm=False,
|
374 |
+
pt_hw_seq_len=16, intp_freq=False, naiveswiglu=False, subln=False):
|
375 |
+
super().__init__()
|
376 |
+
|
377 |
+
if not XFORMERS_IS_AVAILBLE:
|
378 |
+
xattn = False
|
379 |
+
|
380 |
+
self.image_size = img_size
|
381 |
+
self.num_classes = num_classes
|
382 |
+
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
383 |
+
|
384 |
+
self.patch_embed = PatchEmbed(
|
385 |
+
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
386 |
+
num_patches = self.patch_embed.num_patches
|
387 |
+
|
388 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
389 |
+
# self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
390 |
+
if use_abs_pos_emb:
|
391 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
|
392 |
+
else:
|
393 |
+
self.pos_embed = None
|
394 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
395 |
+
|
396 |
+
if use_shared_rel_pos_bias:
|
397 |
+
self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
|
398 |
+
else:
|
399 |
+
self.rel_pos_bias = None
|
400 |
+
|
401 |
+
if rope:
|
402 |
+
half_head_dim = embed_dim // num_heads // 2
|
403 |
+
hw_seq_len = img_size // patch_size
|
404 |
+
self.rope = VisionRotaryEmbeddingFast(
|
405 |
+
dim=half_head_dim,
|
406 |
+
pt_seq_len=pt_hw_seq_len,
|
407 |
+
ft_seq_len=hw_seq_len if intp_freq else None,
|
408 |
+
# patch_dropout=patch_dropout
|
409 |
+
)
|
410 |
+
else:
|
411 |
+
self.rope = None
|
412 |
+
|
413 |
+
self.naiveswiglu = naiveswiglu
|
414 |
+
|
415 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
416 |
+
self.use_rel_pos_bias = use_rel_pos_bias
|
417 |
+
self.blocks = nn.ModuleList([
|
418 |
+
Block(
|
419 |
+
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
420 |
+
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
|
421 |
+
init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None,
|
422 |
+
xattn=xattn, rope=self.rope, postnorm=postnorm, subln=subln, naiveswiglu=naiveswiglu)
|
423 |
+
for i in range(depth)])
|
424 |
+
self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
|
425 |
+
self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
|
426 |
+
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
427 |
+
|
428 |
+
if self.pos_embed is not None:
|
429 |
+
trunc_normal_(self.pos_embed, std=.02)
|
430 |
+
|
431 |
+
trunc_normal_(self.cls_token, std=.02)
|
432 |
+
# trunc_normal_(self.mask_token, std=.02)
|
433 |
+
|
434 |
+
self.apply(self._init_weights)
|
435 |
+
self.fix_init_weight()
|
436 |
+
|
437 |
+
if isinstance(self.head, nn.Linear):
|
438 |
+
trunc_normal_(self.head.weight, std=.02)
|
439 |
+
self.head.weight.data.mul_(init_scale)
|
440 |
+
self.head.bias.data.mul_(init_scale)
|
441 |
+
|
442 |
+
# setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
|
443 |
+
self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()
|
444 |
+
|
445 |
+
self.grad_checkpointing = grad_checkpointing
|
446 |
+
|
447 |
+
def fix_init_weight(self):
|
448 |
+
def rescale(param, layer_id):
|
449 |
+
param.div_(math.sqrt(2.0 * layer_id))
|
450 |
+
|
451 |
+
for layer_id, layer in enumerate(self.blocks):
|
452 |
+
rescale(layer.attn.proj.weight.data, layer_id + 1)
|
453 |
+
if self.naiveswiglu:
|
454 |
+
rescale(layer.mlp.w3.weight.data, layer_id + 1)
|
455 |
+
else:
|
456 |
+
rescale(layer.mlp.fc2.weight.data, layer_id + 1)
|
457 |
+
|
458 |
+
def get_cast_dtype(self) -> torch.dtype:
|
459 |
+
return self.blocks[0].mlp.fc2.weight.dtype
|
460 |
+
|
461 |
+
def _init_weights(self, m):
|
462 |
+
if isinstance(m, nn.Linear):
|
463 |
+
trunc_normal_(m.weight, std=.02)
|
464 |
+
if m.bias is not None:
|
465 |
+
nn.init.constant_(m.bias, 0)
|
466 |
+
elif isinstance(m, nn.LayerNorm):
|
467 |
+
nn.init.constant_(m.bias, 0)
|
468 |
+
nn.init.constant_(m.weight, 1.0)
|
469 |
+
|
470 |
+
def get_num_layers(self):
|
471 |
+
return len(self.blocks)
|
472 |
+
|
473 |
+
def lock(self, unlocked_groups=0, freeze_bn_stats=False):
|
474 |
+
assert unlocked_groups == 0, 'partial locking not currently supported for this model'
|
475 |
+
for param in self.parameters():
|
476 |
+
param.requires_grad = False
|
477 |
+
|
478 |
+
@torch.jit.ignore
|
479 |
+
def set_grad_checkpointing(self, enable=True):
|
480 |
+
self.grad_checkpointing = enable
|
481 |
+
|
482 |
+
@torch.jit.ignore
|
483 |
+
def no_weight_decay(self):
|
484 |
+
return {'pos_embed', 'cls_token'}
|
485 |
+
|
486 |
+
def get_classifier(self):
|
487 |
+
return self.head
|
488 |
+
|
489 |
+
def reset_classifier(self, num_classes, global_pool=''):
|
490 |
+
self.num_classes = num_classes
|
491 |
+
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
492 |
+
|
493 |
+
def forward_features(self, x, return_all_features=False, return_hidden=False, shuffle=False):
|
494 |
+
|
495 |
+
x = self.patch_embed(x)
|
496 |
+
batch_size, seq_len, _ = x.size()
|
497 |
+
|
498 |
+
if shuffle:
|
499 |
+
idx = torch.randperm(x.shape[1]) + 1
|
500 |
+
zero = torch.LongTensor([0, ])
|
501 |
+
idx = torch.cat([zero, idx])
|
502 |
+
pos_embed = self.pos_embed[:, idx]
|
503 |
+
|
504 |
+
cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
|
505 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
506 |
+
if shuffle:
|
507 |
+
x = x + pos_embed
|
508 |
+
elif self.pos_embed is not None:
|
509 |
+
x = x + self.pos_embed
|
510 |
+
x = self.pos_drop(x)
|
511 |
+
|
512 |
+
# a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
|
513 |
+
if os.getenv('RoPE') == '1':
|
514 |
+
if self.training and not isinstance(self.patch_dropout, nn.Identity):
|
515 |
+
x, patch_indices_keep = self.patch_dropout(x)
|
516 |
+
self.rope.forward = partial(self.rope.forward, patch_indices_keep=patch_indices_keep)
|
517 |
+
else:
|
518 |
+
self.rope.forward = partial(self.rope.forward, patch_indices_keep=None)
|
519 |
+
x = self.patch_dropout(x)
|
520 |
+
else:
|
521 |
+
x = self.patch_dropout(x)
|
522 |
+
|
523 |
+
rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
|
524 |
+
hidden_states = []
|
525 |
+
for idx, blk in enumerate(self.blocks):
|
526 |
+
if (0 < idx <= 20) and (idx % 4 == 0) and return_hidden:
|
527 |
+
hidden_states.append(x)
|
528 |
+
if self.grad_checkpointing:
|
529 |
+
x = checkpoint(blk, x, (rel_pos_bias,))
|
530 |
+
else:
|
531 |
+
x = blk(x, rel_pos_bias=rel_pos_bias)
|
532 |
+
|
533 |
+
if not return_all_features:
|
534 |
+
x = self.norm(x)
|
535 |
+
if self.fc_norm is not None:
|
536 |
+
return self.fc_norm(x.mean(1)), hidden_states
|
537 |
+
else:
|
538 |
+
return x[:, 0], hidden_states
|
539 |
+
return x
|
540 |
+
|
541 |
+
def forward(self, x, return_all_features=False, return_hidden=False, shuffle=False):
|
542 |
+
if return_all_features:
|
543 |
+
return self.forward_features(x, return_all_features, return_hidden, shuffle)
|
544 |
+
x, hidden_states = self.forward_features(x, return_all_features, return_hidden, shuffle)
|
545 |
+
x = self.head(x)
|
546 |
+
if return_hidden:
|
547 |
+
return x, hidden_states
|
548 |
+
return x
|
eva_clip/factory.py
ADDED
@@ -0,0 +1,517 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import logging
|
3 |
+
import os
|
4 |
+
import pathlib
|
5 |
+
import re
|
6 |
+
from copy import deepcopy
|
7 |
+
from pathlib import Path
|
8 |
+
from typing import Optional, Tuple, Union, Dict, Any
|
9 |
+
import torch
|
10 |
+
|
11 |
+
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
|
12 |
+
from .model import CLIP, CustomCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
|
13 |
+
get_cast_dtype
|
14 |
+
from .openai import load_openai_model
|
15 |
+
from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model
|
16 |
+
from .transform import image_transform
|
17 |
+
from .tokenizer import HFTokenizer, tokenize
|
18 |
+
from .utils import resize_clip_pos_embed, resize_evaclip_pos_embed, resize_visual_pos_embed, resize_eva_pos_embed
|
19 |
+
|
20 |
+
|
21 |
+
_MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
|
22 |
+
_MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs
|
23 |
+
|
24 |
+
|
25 |
+
def _natural_key(string_):
|
26 |
+
return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
|
27 |
+
|
28 |
+
|
29 |
+
def _rescan_model_configs():
|
30 |
+
global _MODEL_CONFIGS
|
31 |
+
|
32 |
+
config_ext = ('.json',)
|
33 |
+
config_files = []
|
34 |
+
for config_path in _MODEL_CONFIG_PATHS:
|
35 |
+
if config_path.is_file() and config_path.suffix in config_ext:
|
36 |
+
config_files.append(config_path)
|
37 |
+
elif config_path.is_dir():
|
38 |
+
for ext in config_ext:
|
39 |
+
config_files.extend(config_path.glob(f'*{ext}'))
|
40 |
+
|
41 |
+
for cf in config_files:
|
42 |
+
with open(cf, "r", encoding="utf8") as f:
|
43 |
+
model_cfg = json.load(f)
|
44 |
+
if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')):
|
45 |
+
_MODEL_CONFIGS[cf.stem] = model_cfg
|
46 |
+
|
47 |
+
_MODEL_CONFIGS = dict(sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0])))
|
48 |
+
|
49 |
+
|
50 |
+
_rescan_model_configs() # initial populate of model config registry
|
51 |
+
|
52 |
+
|
53 |
+
def list_models():
|
54 |
+
""" enumerate available model architectures based on config files """
|
55 |
+
return list(_MODEL_CONFIGS.keys())
|
56 |
+
|
57 |
+
|
58 |
+
def add_model_config(path):
|
59 |
+
""" add model config path or file and update registry """
|
60 |
+
if not isinstance(path, Path):
|
61 |
+
path = Path(path)
|
62 |
+
_MODEL_CONFIG_PATHS.append(path)
|
63 |
+
_rescan_model_configs()
|
64 |
+
|
65 |
+
|
66 |
+
def get_model_config(model_name):
|
67 |
+
if model_name in _MODEL_CONFIGS:
|
68 |
+
return deepcopy(_MODEL_CONFIGS[model_name])
|
69 |
+
else:
|
70 |
+
return None
|
71 |
+
|
72 |
+
|
73 |
+
def get_tokenizer(model_name):
|
74 |
+
config = get_model_config(model_name)
|
75 |
+
tokenizer = HFTokenizer(config['text_cfg']['hf_tokenizer_name']) if 'hf_tokenizer_name' in config['text_cfg'] else tokenize
|
76 |
+
return tokenizer
|
77 |
+
|
78 |
+
|
79 |
+
# loading openai CLIP weights when is_openai=True for training
|
80 |
+
def load_state_dict(checkpoint_path: str, map_location: str='cpu', model_key: str='model|module|state_dict', is_openai: bool=False, skip_list: list=[]):
|
81 |
+
if is_openai:
|
82 |
+
model = torch.jit.load(checkpoint_path, map_location="cpu").eval()
|
83 |
+
state_dict = model.state_dict()
|
84 |
+
for key in ["input_resolution", "context_length", "vocab_size"]:
|
85 |
+
state_dict.pop(key, None)
|
86 |
+
else:
|
87 |
+
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
88 |
+
for mk in model_key.split('|'):
|
89 |
+
if isinstance(checkpoint, dict) and mk in checkpoint:
|
90 |
+
state_dict = checkpoint[mk]
|
91 |
+
break
|
92 |
+
else:
|
93 |
+
state_dict = checkpoint
|
94 |
+
if next(iter(state_dict.items()))[0].startswith('module'):
|
95 |
+
state_dict = {k[7:]: v for k, v in state_dict.items()}
|
96 |
+
|
97 |
+
for k in skip_list:
|
98 |
+
if k in list(state_dict.keys()):
|
99 |
+
logging.info(f"Removing key {k} from pretrained checkpoint")
|
100 |
+
del state_dict[k]
|
101 |
+
|
102 |
+
if os.getenv('RoPE') == '1':
|
103 |
+
for k in list(state_dict.keys()):
|
104 |
+
if 'freqs_cos' in k or 'freqs_sin' in k:
|
105 |
+
del state_dict[k]
|
106 |
+
return state_dict
|
107 |
+
|
108 |
+
|
109 |
+
|
110 |
+
def load_checkpoint(model, checkpoint_path, model_key="model|module|state_dict", strict=True):
|
111 |
+
state_dict = load_state_dict(checkpoint_path, model_key=model_key, is_openai=False)
|
112 |
+
# detect old format and make compatible with new format
|
113 |
+
if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
|
114 |
+
state_dict = convert_to_custom_text_state_dict(state_dict)
|
115 |
+
if 'text.logit_scale' in state_dict and hasattr(model, 'logit_scale'):
|
116 |
+
state_dict['logit_scale'] = state_dict['text.logit_scale']
|
117 |
+
del state_dict['text.logit_scale']
|
118 |
+
|
119 |
+
# resize_clip_pos_embed for CLIP and open CLIP
|
120 |
+
if 'visual.positional_embedding' in state_dict:
|
121 |
+
resize_clip_pos_embed(state_dict, model)
|
122 |
+
# specified to eva_vit_model
|
123 |
+
elif 'visual.pos_embed' in state_dict:
|
124 |
+
resize_evaclip_pos_embed(state_dict, model)
|
125 |
+
|
126 |
+
# resize_clip_pos_embed(state_dict, model)
|
127 |
+
incompatible_keys = model.load_state_dict(state_dict, strict=strict)
|
128 |
+
logging.info(f"incompatible_keys.missing_keys: {incompatible_keys.missing_keys}")
|
129 |
+
return incompatible_keys
|
130 |
+
|
131 |
+
def load_clip_visual_state_dict(checkpoint_path: str, map_location: str='cpu', is_openai: bool=False, skip_list:list=[]):
|
132 |
+
state_dict = load_state_dict(checkpoint_path, map_location=map_location, is_openai=is_openai, skip_list=skip_list)
|
133 |
+
|
134 |
+
for k in list(state_dict.keys()):
|
135 |
+
if not k.startswith('visual.'):
|
136 |
+
del state_dict[k]
|
137 |
+
for k in list(state_dict.keys()):
|
138 |
+
if k.startswith('visual.'):
|
139 |
+
new_k = k[7:]
|
140 |
+
state_dict[new_k] = state_dict[k]
|
141 |
+
del state_dict[k]
|
142 |
+
return state_dict
|
143 |
+
|
144 |
+
def load_clip_text_state_dict(checkpoint_path: str, map_location: str='cpu', is_openai: bool=False, skip_list:list=[]):
|
145 |
+
state_dict = load_state_dict(checkpoint_path, map_location=map_location, is_openai=is_openai, skip_list=skip_list)
|
146 |
+
|
147 |
+
for k in list(state_dict.keys()):
|
148 |
+
if k.startswith('visual.'):
|
149 |
+
del state_dict[k]
|
150 |
+
return state_dict
|
151 |
+
|
152 |
+
def get_pretrained_tag(pretrained_model):
|
153 |
+
pretrained_model = pretrained_model.lower()
|
154 |
+
if "laion" in pretrained_model or "open_clip" in pretrained_model:
|
155 |
+
return "open_clip"
|
156 |
+
elif "openai" in pretrained_model:
|
157 |
+
return "clip"
|
158 |
+
elif "eva" in pretrained_model and "clip" in pretrained_model:
|
159 |
+
return "eva_clip"
|
160 |
+
else:
|
161 |
+
return "other"
|
162 |
+
|
163 |
+
def load_pretrained_checkpoint(
|
164 |
+
model,
|
165 |
+
visual_checkpoint_path,
|
166 |
+
text_checkpoint_path,
|
167 |
+
strict=True,
|
168 |
+
visual_model=None,
|
169 |
+
text_model=None,
|
170 |
+
model_key="model|module|state_dict",
|
171 |
+
skip_list=[]):
|
172 |
+
visual_tag = get_pretrained_tag(visual_model)
|
173 |
+
text_tag = get_pretrained_tag(text_model)
|
174 |
+
|
175 |
+
logging.info(f"num of model state_dict keys: {len(model.state_dict().keys())}")
|
176 |
+
visual_incompatible_keys, text_incompatible_keys = None, None
|
177 |
+
if visual_checkpoint_path:
|
178 |
+
if visual_tag == "eva_clip" or visual_tag == "open_clip":
|
179 |
+
visual_state_dict = load_clip_visual_state_dict(visual_checkpoint_path, is_openai=False, skip_list=skip_list)
|
180 |
+
elif visual_tag == "clip":
|
181 |
+
visual_state_dict = load_clip_visual_state_dict(visual_checkpoint_path, is_openai=True, skip_list=skip_list)
|
182 |
+
else:
|
183 |
+
visual_state_dict = load_state_dict(visual_checkpoint_path, model_key=model_key, is_openai=False, skip_list=skip_list)
|
184 |
+
|
185 |
+
# resize_clip_pos_embed for CLIP and open CLIP
|
186 |
+
if 'positional_embedding' in visual_state_dict:
|
187 |
+
resize_visual_pos_embed(visual_state_dict, model)
|
188 |
+
# specified to EVA model
|
189 |
+
elif 'pos_embed' in visual_state_dict:
|
190 |
+
resize_eva_pos_embed(visual_state_dict, model)
|
191 |
+
|
192 |
+
visual_incompatible_keys = model.visual.load_state_dict(visual_state_dict, strict=strict)
|
193 |
+
logging.info(f"num of loaded visual_state_dict keys: {len(visual_state_dict.keys())}")
|
194 |
+
logging.info(f"visual_incompatible_keys.missing_keys: {visual_incompatible_keys.missing_keys}")
|
195 |
+
|
196 |
+
if text_checkpoint_path:
|
197 |
+
if text_tag == "eva_clip" or text_tag == "open_clip":
|
198 |
+
text_state_dict = load_clip_text_state_dict(text_checkpoint_path, is_openai=False, skip_list=skip_list)
|
199 |
+
elif text_tag == "clip":
|
200 |
+
text_state_dict = load_clip_text_state_dict(text_checkpoint_path, is_openai=True, skip_list=skip_list)
|
201 |
+
else:
|
202 |
+
text_state_dict = load_state_dict(visual_checkpoint_path, model_key=model_key, is_openai=False, skip_list=skip_list)
|
203 |
+
|
204 |
+
text_incompatible_keys = model.text.load_state_dict(text_state_dict, strict=strict)
|
205 |
+
|
206 |
+
logging.info(f"num of loaded text_state_dict keys: {len(text_state_dict.keys())}")
|
207 |
+
logging.info(f"text_incompatible_keys.missing_keys: {text_incompatible_keys.missing_keys}")
|
208 |
+
|
209 |
+
return visual_incompatible_keys, text_incompatible_keys
|
210 |
+
|
211 |
+
def create_model(
|
212 |
+
model_name: str,
|
213 |
+
pretrained: Optional[str] = None,
|
214 |
+
precision: str = 'fp32',
|
215 |
+
device: Union[str, torch.device] = 'cpu',
|
216 |
+
jit: bool = False,
|
217 |
+
force_quick_gelu: bool = False,
|
218 |
+
force_custom_clip: bool = False,
|
219 |
+
force_patch_dropout: Optional[float] = None,
|
220 |
+
pretrained_image: str = '',
|
221 |
+
pretrained_text: str = '',
|
222 |
+
pretrained_hf: bool = True,
|
223 |
+
pretrained_visual_model: str = None,
|
224 |
+
pretrained_text_model: str = None,
|
225 |
+
cache_dir: Optional[str] = None,
|
226 |
+
skip_list: list = [],
|
227 |
+
):
|
228 |
+
model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names
|
229 |
+
if isinstance(device, str):
|
230 |
+
device = torch.device(device)
|
231 |
+
|
232 |
+
if pretrained and pretrained.lower() == 'openai':
|
233 |
+
logging.info(f'Loading pretrained {model_name} from OpenAI.')
|
234 |
+
model = load_openai_model(
|
235 |
+
model_name,
|
236 |
+
precision=precision,
|
237 |
+
device=device,
|
238 |
+
jit=jit,
|
239 |
+
cache_dir=cache_dir,
|
240 |
+
)
|
241 |
+
else:
|
242 |
+
model_cfg = get_model_config(model_name)
|
243 |
+
if model_cfg is not None:
|
244 |
+
logging.info(f'Loaded {model_name} model config.')
|
245 |
+
else:
|
246 |
+
logging.error(f'Model config for {model_name} not found; available models {list_models()}.')
|
247 |
+
raise RuntimeError(f'Model config for {model_name} not found.')
|
248 |
+
|
249 |
+
if 'rope' in model_cfg.get('vision_cfg', {}):
|
250 |
+
if model_cfg['vision_cfg']['rope']:
|
251 |
+
os.environ['RoPE'] = "1"
|
252 |
+
else:
|
253 |
+
os.environ['RoPE'] = "0"
|
254 |
+
|
255 |
+
if force_quick_gelu:
|
256 |
+
# override for use of QuickGELU on non-OpenAI transformer models
|
257 |
+
model_cfg["quick_gelu"] = True
|
258 |
+
|
259 |
+
if force_patch_dropout is not None:
|
260 |
+
# override the default patch dropout value
|
261 |
+
model_cfg['vision_cfg']["patch_dropout"] = force_patch_dropout
|
262 |
+
|
263 |
+
cast_dtype = get_cast_dtype(precision)
|
264 |
+
custom_clip = model_cfg.pop('custom_text', False) or force_custom_clip or ('hf_model_name' in model_cfg['text_cfg'])
|
265 |
+
|
266 |
+
|
267 |
+
if custom_clip:
|
268 |
+
if 'hf_model_name' in model_cfg.get('text_cfg', {}):
|
269 |
+
model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf
|
270 |
+
model = CustomCLIP(**model_cfg, cast_dtype=cast_dtype)
|
271 |
+
else:
|
272 |
+
model = CLIP(**model_cfg, cast_dtype=cast_dtype)
|
273 |
+
|
274 |
+
pretrained_cfg = {}
|
275 |
+
if pretrained:
|
276 |
+
checkpoint_path = ''
|
277 |
+
pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
|
278 |
+
if pretrained_cfg:
|
279 |
+
checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
|
280 |
+
elif os.path.exists(pretrained):
|
281 |
+
checkpoint_path = pretrained
|
282 |
+
|
283 |
+
if checkpoint_path:
|
284 |
+
logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
|
285 |
+
load_checkpoint(model,
|
286 |
+
checkpoint_path,
|
287 |
+
model_key="model|module|state_dict",
|
288 |
+
strict=False
|
289 |
+
)
|
290 |
+
else:
|
291 |
+
error_str = (
|
292 |
+
f'Pretrained weights ({pretrained}) not found for model {model_name}.'
|
293 |
+
f'Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.')
|
294 |
+
logging.warning(error_str)
|
295 |
+
raise RuntimeError(error_str)
|
296 |
+
else:
|
297 |
+
visual_checkpoint_path = ''
|
298 |
+
text_checkpoint_path = ''
|
299 |
+
|
300 |
+
if pretrained_image:
|
301 |
+
pretrained_visual_model = pretrained_visual_model.replace('/', '-') # for callers using old naming with / in ViT names
|
302 |
+
pretrained_image_cfg = get_pretrained_cfg(pretrained_visual_model, pretrained_image)
|
303 |
+
if 'timm_model_name' in model_cfg.get('vision_cfg', {}):
|
304 |
+
# pretrained weight loading for timm models set via vision_cfg
|
305 |
+
model_cfg['vision_cfg']['timm_model_pretrained'] = True
|
306 |
+
elif pretrained_image_cfg:
|
307 |
+
visual_checkpoint_path = download_pretrained(pretrained_image_cfg, cache_dir=cache_dir)
|
308 |
+
elif os.path.exists(pretrained_image):
|
309 |
+
visual_checkpoint_path = pretrained_image
|
310 |
+
else:
|
311 |
+
logging.warning(f'Pretrained weights ({visual_checkpoint_path}) not found for model {model_name}.visual.')
|
312 |
+
raise RuntimeError(f'Pretrained weights ({visual_checkpoint_path}) not found for model {model_name}.visual.')
|
313 |
+
|
314 |
+
if pretrained_text:
|
315 |
+
pretrained_text_model = pretrained_text_model.replace('/', '-') # for callers using old naming with / in ViT names
|
316 |
+
pretrained_text_cfg = get_pretrained_cfg(pretrained_text_model, pretrained_text)
|
317 |
+
if pretrained_image_cfg:
|
318 |
+
text_checkpoint_path = download_pretrained(pretrained_text_cfg, cache_dir=cache_dir)
|
319 |
+
elif os.path.exists(pretrained_text):
|
320 |
+
text_checkpoint_path = pretrained_text
|
321 |
+
else:
|
322 |
+
logging.warning(f'Pretrained weights ({text_checkpoint_path}) not found for model {model_name}.text.')
|
323 |
+
raise RuntimeError(f'Pretrained weights ({text_checkpoint_path}) not found for model {model_name}.text.')
|
324 |
+
|
325 |
+
if visual_checkpoint_path:
|
326 |
+
logging.info(f'Loading pretrained {model_name}.visual weights ({visual_checkpoint_path}).')
|
327 |
+
if text_checkpoint_path:
|
328 |
+
logging.info(f'Loading pretrained {model_name}.text weights ({text_checkpoint_path}).')
|
329 |
+
|
330 |
+
if visual_checkpoint_path or text_checkpoint_path:
|
331 |
+
load_pretrained_checkpoint(
|
332 |
+
model,
|
333 |
+
visual_checkpoint_path,
|
334 |
+
text_checkpoint_path,
|
335 |
+
strict=False,
|
336 |
+
visual_model=pretrained_visual_model,
|
337 |
+
text_model=pretrained_text_model,
|
338 |
+
model_key="model|module|state_dict",
|
339 |
+
skip_list=skip_list
|
340 |
+
)
|
341 |
+
|
342 |
+
if "fp16" in precision or "bf16" in precision:
|
343 |
+
logging.info(f'convert precision to {precision}')
|
344 |
+
model = model.to(torch.bfloat16) if 'bf16' in precision else model.to(torch.float16)
|
345 |
+
|
346 |
+
model.to(device=device)
|
347 |
+
|
348 |
+
# set image / mean metadata from pretrained_cfg if available, or use default
|
349 |
+
model.visual.image_mean = pretrained_cfg.get('mean', None) or OPENAI_DATASET_MEAN
|
350 |
+
model.visual.image_std = pretrained_cfg.get('std', None) or OPENAI_DATASET_STD
|
351 |
+
|
352 |
+
if jit:
|
353 |
+
model = torch.jit.script(model)
|
354 |
+
|
355 |
+
return model
|
356 |
+
|
357 |
+
|
358 |
+
def create_model_and_transforms(
|
359 |
+
model_name: str,
|
360 |
+
pretrained: Optional[str] = None,
|
361 |
+
precision: str = 'fp32',
|
362 |
+
device: Union[str, torch.device] = 'cpu',
|
363 |
+
jit: bool = False,
|
364 |
+
force_quick_gelu: bool = False,
|
365 |
+
force_custom_clip: bool = False,
|
366 |
+
force_patch_dropout: Optional[float] = None,
|
367 |
+
pretrained_image: str = '',
|
368 |
+
pretrained_text: str = '',
|
369 |
+
pretrained_hf: bool = True,
|
370 |
+
pretrained_visual_model: str = None,
|
371 |
+
pretrained_text_model: str = None,
|
372 |
+
image_mean: Optional[Tuple[float, ...]] = None,
|
373 |
+
image_std: Optional[Tuple[float, ...]] = None,
|
374 |
+
cache_dir: Optional[str] = None,
|
375 |
+
skip_list: list = [],
|
376 |
+
):
|
377 |
+
model = create_model(
|
378 |
+
model_name,
|
379 |
+
pretrained,
|
380 |
+
precision=precision,
|
381 |
+
device=device,
|
382 |
+
jit=jit,
|
383 |
+
force_quick_gelu=force_quick_gelu,
|
384 |
+
force_custom_clip=force_custom_clip,
|
385 |
+
force_patch_dropout=force_patch_dropout,
|
386 |
+
pretrained_image=pretrained_image,
|
387 |
+
pretrained_text=pretrained_text,
|
388 |
+
pretrained_hf=pretrained_hf,
|
389 |
+
pretrained_visual_model=pretrained_visual_model,
|
390 |
+
pretrained_text_model=pretrained_text_model,
|
391 |
+
cache_dir=cache_dir,
|
392 |
+
skip_list=skip_list,
|
393 |
+
)
|
394 |
+
|
395 |
+
image_mean = image_mean or getattr(model.visual, 'image_mean', None)
|
396 |
+
image_std = image_std or getattr(model.visual, 'image_std', None)
|
397 |
+
preprocess_train = image_transform(
|
398 |
+
model.visual.image_size,
|
399 |
+
is_train=True,
|
400 |
+
mean=image_mean,
|
401 |
+
std=image_std
|
402 |
+
)
|
403 |
+
preprocess_val = image_transform(
|
404 |
+
model.visual.image_size,
|
405 |
+
is_train=False,
|
406 |
+
mean=image_mean,
|
407 |
+
std=image_std
|
408 |
+
)
|
409 |
+
|
410 |
+
return model, preprocess_train, preprocess_val
|
411 |
+
|
412 |
+
|
413 |
+
def create_transforms(
|
414 |
+
model_name: str,
|
415 |
+
pretrained: Optional[str] = None,
|
416 |
+
precision: str = 'fp32',
|
417 |
+
device: Union[str, torch.device] = 'cpu',
|
418 |
+
jit: bool = False,
|
419 |
+
force_quick_gelu: bool = False,
|
420 |
+
force_custom_clip: bool = False,
|
421 |
+
force_patch_dropout: Optional[float] = None,
|
422 |
+
pretrained_image: str = '',
|
423 |
+
pretrained_text: str = '',
|
424 |
+
pretrained_hf: bool = True,
|
425 |
+
pretrained_visual_model: str = None,
|
426 |
+
pretrained_text_model: str = None,
|
427 |
+
image_mean: Optional[Tuple[float, ...]] = None,
|
428 |
+
image_std: Optional[Tuple[float, ...]] = None,
|
429 |
+
cache_dir: Optional[str] = None,
|
430 |
+
skip_list: list = [],
|
431 |
+
):
|
432 |
+
model = create_model(
|
433 |
+
model_name,
|
434 |
+
pretrained,
|
435 |
+
precision=precision,
|
436 |
+
device=device,
|
437 |
+
jit=jit,
|
438 |
+
force_quick_gelu=force_quick_gelu,
|
439 |
+
force_custom_clip=force_custom_clip,
|
440 |
+
force_patch_dropout=force_patch_dropout,
|
441 |
+
pretrained_image=pretrained_image,
|
442 |
+
pretrained_text=pretrained_text,
|
443 |
+
pretrained_hf=pretrained_hf,
|
444 |
+
pretrained_visual_model=pretrained_visual_model,
|
445 |
+
pretrained_text_model=pretrained_text_model,
|
446 |
+
cache_dir=cache_dir,
|
447 |
+
skip_list=skip_list,
|
448 |
+
)
|
449 |
+
|
450 |
+
|
451 |
+
image_mean = image_mean or getattr(model.visual, 'image_mean', None)
|
452 |
+
image_std = image_std or getattr(model.visual, 'image_std', None)
|
453 |
+
preprocess_train = image_transform(
|
454 |
+
model.visual.image_size,
|
455 |
+
is_train=True,
|
456 |
+
mean=image_mean,
|
457 |
+
std=image_std
|
458 |
+
)
|
459 |
+
preprocess_val = image_transform(
|
460 |
+
model.visual.image_size,
|
461 |
+
is_train=False,
|
462 |
+
mean=image_mean,
|
463 |
+
std=image_std
|
464 |
+
)
|
465 |
+
del model
|
466 |
+
|
467 |
+
return preprocess_train, preprocess_val
|
468 |
+
|
469 |
+
def create_model_from_pretrained(
|
470 |
+
model_name: str,
|
471 |
+
pretrained: str,
|
472 |
+
precision: str = 'fp32',
|
473 |
+
device: Union[str, torch.device] = 'cpu',
|
474 |
+
jit: bool = False,
|
475 |
+
force_quick_gelu: bool = False,
|
476 |
+
force_custom_clip: bool = False,
|
477 |
+
force_patch_dropout: Optional[float] = None,
|
478 |
+
return_transform: bool = True,
|
479 |
+
image_mean: Optional[Tuple[float, ...]] = None,
|
480 |
+
image_std: Optional[Tuple[float, ...]] = None,
|
481 |
+
cache_dir: Optional[str] = None,
|
482 |
+
is_frozen: bool = False,
|
483 |
+
):
|
484 |
+
if not is_pretrained_cfg(model_name, pretrained) and not os.path.exists(pretrained):
|
485 |
+
raise RuntimeError(
|
486 |
+
f'{pretrained} is not a valid pretrained cfg or checkpoint for {model_name}.'
|
487 |
+
f' Use open_clip.list_pretrained() to find one.')
|
488 |
+
|
489 |
+
model = create_model(
|
490 |
+
model_name,
|
491 |
+
pretrained,
|
492 |
+
precision=precision,
|
493 |
+
device=device,
|
494 |
+
jit=jit,
|
495 |
+
force_quick_gelu=force_quick_gelu,
|
496 |
+
force_custom_clip=force_custom_clip,
|
497 |
+
force_patch_dropout=force_patch_dropout,
|
498 |
+
cache_dir=cache_dir,
|
499 |
+
)
|
500 |
+
|
501 |
+
if is_frozen:
|
502 |
+
for param in model.parameters():
|
503 |
+
param.requires_grad = False
|
504 |
+
|
505 |
+
if not return_transform:
|
506 |
+
return model
|
507 |
+
|
508 |
+
image_mean = image_mean or getattr(model.visual, 'image_mean', None)
|
509 |
+
image_std = image_std or getattr(model.visual, 'image_std', None)
|
510 |
+
preprocess = image_transform(
|
511 |
+
model.visual.image_size,
|
512 |
+
is_train=False,
|
513 |
+
mean=image_mean,
|
514 |
+
std=image_std
|
515 |
+
)
|
516 |
+
|
517 |
+
return model, preprocess
|
eva_clip/hf_configs.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# HF architecture dict:
|
2 |
+
arch_dict = {
|
3 |
+
# https://huggingface.co/docs/transformers/model_doc/roberta#roberta
|
4 |
+
"roberta": {
|
5 |
+
"config_names": {
|
6 |
+
"context_length": "max_position_embeddings",
|
7 |
+
"vocab_size": "vocab_size",
|
8 |
+
"width": "hidden_size",
|
9 |
+
"heads": "num_attention_heads",
|
10 |
+
"layers": "num_hidden_layers",
|
11 |
+
"layer_attr": "layer",
|
12 |
+
"token_embeddings_attr": "embeddings"
|
13 |
+
},
|
14 |
+
"pooler": "mean_pooler",
|
15 |
+
},
|
16 |
+
# https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
|
17 |
+
"xlm-roberta": {
|
18 |
+
"config_names": {
|
19 |
+
"context_length": "max_position_embeddings",
|
20 |
+
"vocab_size": "vocab_size",
|
21 |
+
"width": "hidden_size",
|
22 |
+
"heads": "num_attention_heads",
|
23 |
+
"layers": "num_hidden_layers",
|
24 |
+
"layer_attr": "layer",
|
25 |
+
"token_embeddings_attr": "embeddings"
|
26 |
+
},
|
27 |
+
"pooler": "mean_pooler",
|
28 |
+
},
|
29 |
+
# https://huggingface.co/docs/transformers/model_doc/mt5#mt5
|
30 |
+
"mt5": {
|
31 |
+
"config_names": {
|
32 |
+
# unlimited seqlen
|
33 |
+
# https://github.com/google-research/text-to-text-transfer-transformer/issues/273
|
34 |
+
# https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
|
35 |
+
"context_length": "",
|
36 |
+
"vocab_size": "vocab_size",
|
37 |
+
"width": "d_model",
|
38 |
+
"heads": "num_heads",
|
39 |
+
"layers": "num_layers",
|
40 |
+
"layer_attr": "block",
|
41 |
+
"token_embeddings_attr": "embed_tokens"
|
42 |
+
},
|
43 |
+
"pooler": "mean_pooler",
|
44 |
+
},
|
45 |
+
"bert": {
|
46 |
+
"config_names": {
|
47 |
+
"context_length": "max_position_embeddings",
|
48 |
+
"vocab_size": "vocab_size",
|
49 |
+
"width": "hidden_size",
|
50 |
+
"heads": "num_attention_heads",
|
51 |
+
"layers": "num_hidden_layers",
|
52 |
+
"layer_attr": "layer",
|
53 |
+
"token_embeddings_attr": "embeddings"
|
54 |
+
},
|
55 |
+
"pooler": "mean_pooler",
|
56 |
+
}
|
57 |
+
}
|
eva_clip/hf_model.py
ADDED
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" huggingface model adapter
|
2 |
+
|
3 |
+
Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model.
|
4 |
+
"""
|
5 |
+
|
6 |
+
import re
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
from torch.nn import functional as F
|
11 |
+
from torch import TensorType
|
12 |
+
try:
|
13 |
+
import transformers
|
14 |
+
from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer, AutoConfig, PretrainedConfig
|
15 |
+
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \
|
16 |
+
BaseModelOutputWithPoolingAndCrossAttentions
|
17 |
+
except ImportError as e:
|
18 |
+
transformers = None
|
19 |
+
|
20 |
+
|
21 |
+
class BaseModelOutput:
|
22 |
+
pass
|
23 |
+
|
24 |
+
|
25 |
+
class PretrainedConfig:
|
26 |
+
pass
|
27 |
+
|
28 |
+
from .hf_configs import arch_dict
|
29 |
+
|
30 |
+
# utils
|
31 |
+
def _camel2snake(s):
|
32 |
+
return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()
|
33 |
+
|
34 |
+
# TODO: ?last - for gpt-like models
|
35 |
+
_POOLERS = {}
|
36 |
+
|
37 |
+
def register_pooler(cls):
|
38 |
+
"""Decorator registering pooler class"""
|
39 |
+
_POOLERS[_camel2snake(cls.__name__)] = cls
|
40 |
+
return cls
|
41 |
+
|
42 |
+
|
43 |
+
@register_pooler
|
44 |
+
class MeanPooler(nn.Module):
|
45 |
+
"""Mean pooling"""
|
46 |
+
def forward(self, x:BaseModelOutput, attention_mask:TensorType):
|
47 |
+
masked_output = x.last_hidden_state * attention_mask.unsqueeze(-1)
|
48 |
+
return masked_output.sum(dim=1) / attention_mask.sum(-1, keepdim=True)
|
49 |
+
|
50 |
+
@register_pooler
|
51 |
+
class MaxPooler(nn.Module):
|
52 |
+
"""Max pooling"""
|
53 |
+
def forward(self, x:BaseModelOutput, attention_mask:TensorType):
|
54 |
+
masked_output = x.last_hidden_state.masked_fill(attention_mask.unsqueeze(-1), -torch.inf)
|
55 |
+
return masked_output.max(1).values
|
56 |
+
|
57 |
+
@register_pooler
|
58 |
+
class ClsPooler(nn.Module):
|
59 |
+
"""CLS token pooling"""
|
60 |
+
def __init__(self, use_pooler_output=True):
|
61 |
+
super().__init__()
|
62 |
+
self.cls_token_position = 0
|
63 |
+
self.use_pooler_output = use_pooler_output
|
64 |
+
|
65 |
+
def forward(self, x:BaseModelOutput, attention_mask:TensorType):
|
66 |
+
|
67 |
+
if (self.use_pooler_output and
|
68 |
+
isinstance(x, (BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions)) and
|
69 |
+
(x.pooler_output is not None)
|
70 |
+
):
|
71 |
+
return x.pooler_output
|
72 |
+
|
73 |
+
return x.last_hidden_state[:, self.cls_token_position, :]
|
74 |
+
|
75 |
+
class HFTextEncoder(nn.Module):
|
76 |
+
"""HuggingFace model adapter"""
|
77 |
+
def __init__(
|
78 |
+
self,
|
79 |
+
model_name_or_path: str,
|
80 |
+
output_dim: int,
|
81 |
+
tokenizer_name: str = None,
|
82 |
+
config: PretrainedConfig = None,
|
83 |
+
pooler_type: str = None,
|
84 |
+
proj: str = None,
|
85 |
+
pretrained: bool = True,
|
86 |
+
masked_language_modeling: bool = False):
|
87 |
+
super().__init__()
|
88 |
+
|
89 |
+
self.output_dim = output_dim
|
90 |
+
|
91 |
+
# TODO: find better way to get this information
|
92 |
+
uses_transformer_pooler = (pooler_type == "cls_pooler")
|
93 |
+
|
94 |
+
if transformers is None:
|
95 |
+
raise RuntimeError("Please `pip install transformers` to use pre-trained HuggingFace models")
|
96 |
+
if config is None:
|
97 |
+
self.config = AutoConfig.from_pretrained(model_name_or_path)
|
98 |
+
if masked_language_modeling:
|
99 |
+
create_func, model_args = (AutoModelForMaskedLM.from_pretrained, model_name_or_path) if pretrained else (
|
100 |
+
AutoModelForMaskedLM.from_config, self.config)
|
101 |
+
else:
|
102 |
+
create_func, model_args = (AutoModel.from_pretrained, model_name_or_path) if pretrained else (
|
103 |
+
AutoModel.from_config, self.config)
|
104 |
+
# TODO: do all model configs have this attribute? PretrainedConfig does so yes??
|
105 |
+
if hasattr(self.config, "is_encoder_decoder") and self.config.is_encoder_decoder:
|
106 |
+
self.transformer = create_func(model_args)
|
107 |
+
self.transformer = self.transformer.encoder
|
108 |
+
else:
|
109 |
+
self.transformer = create_func(model_args, add_pooling_layer=uses_transformer_pooler)
|
110 |
+
else:
|
111 |
+
self.config = config
|
112 |
+
if masked_language_modeling:
|
113 |
+
self.transformer = AutoModelForMaskedLM.from_config(config)
|
114 |
+
else:
|
115 |
+
self.transformer = AutoModel.from_config(config)
|
116 |
+
|
117 |
+
if pooler_type is None: # get default arch pooler
|
118 |
+
self.pooler = _POOLERS[(arch_dict[self.config.model_type]["pooler"])]()
|
119 |
+
else:
|
120 |
+
self.pooler = _POOLERS[pooler_type]()
|
121 |
+
|
122 |
+
d_model = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["width"])
|
123 |
+
if (d_model == output_dim) and (proj is None): # do we always need a proj?
|
124 |
+
self.proj = nn.Identity()
|
125 |
+
elif proj == 'linear':
|
126 |
+
self.proj = nn.Linear(d_model, output_dim, bias=False)
|
127 |
+
elif proj == 'mlp':
|
128 |
+
hidden_size = (d_model + output_dim) // 2
|
129 |
+
self.proj = nn.Sequential(
|
130 |
+
nn.Linear(d_model, hidden_size, bias=False),
|
131 |
+
nn.GELU(),
|
132 |
+
nn.Linear(hidden_size, output_dim, bias=False),
|
133 |
+
)
|
134 |
+
|
135 |
+
# self.itm_proj = nn.Linear(d_model, 2, bias=False)
|
136 |
+
# self.mlm_proj = nn.Linear(d_model, self.config.vocab_size), bias=False)
|
137 |
+
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
|
138 |
+
|
139 |
+
# def forward_itm(self, x:TensorType, image_embeds:TensorType) -> TensorType:
|
140 |
+
# image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(x.device)
|
141 |
+
# attn_mask = (x != self.config.pad_token_id).long()
|
142 |
+
# out = self.transformer(
|
143 |
+
# input_ids=x,
|
144 |
+
# attention_mask=attn_mask,
|
145 |
+
# encoder_hidden_states = image_embeds,
|
146 |
+
# encoder_attention_mask = image_atts,
|
147 |
+
# )
|
148 |
+
# pooled_out = self.pooler(out, attn_mask)
|
149 |
+
|
150 |
+
# return self.itm_proj(pooled_out)
|
151 |
+
|
152 |
+
def mask(self, input_ids, vocab_size, device, targets=None, masked_indices=None, probability_matrix=None):
|
153 |
+
if masked_indices is None:
|
154 |
+
masked_indices = torch.bernoulli(probability_matrix).bool()
|
155 |
+
|
156 |
+
masked_indices[input_ids == self.tokenizer.pad_token_id] = False
|
157 |
+
masked_indices[input_ids == self.tokenizer.cls_token_id] = False
|
158 |
+
|
159 |
+
if targets is not None:
|
160 |
+
targets[~masked_indices] = -100 # We only compute loss on masked tokens
|
161 |
+
|
162 |
+
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
|
163 |
+
indices_replaced = torch.bernoulli(torch.full(input_ids.shape, 0.8)).bool() & masked_indices
|
164 |
+
input_ids[indices_replaced] = self.tokenizer.mask_token_id
|
165 |
+
|
166 |
+
# 10% of the time, we replace masked input tokens with random word
|
167 |
+
indices_random = torch.bernoulli(torch.full(input_ids.shape, 0.5)).bool() & masked_indices & ~indices_replaced
|
168 |
+
random_words = torch.randint(vocab_size, input_ids.shape, dtype=torch.long).to(device)
|
169 |
+
input_ids[indices_random] = random_words[indices_random]
|
170 |
+
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
|
171 |
+
|
172 |
+
if targets is not None:
|
173 |
+
return input_ids, targets
|
174 |
+
else:
|
175 |
+
return input_ids
|
176 |
+
|
177 |
+
def forward_mlm(self, input_ids, image_embeds, mlm_probability=0.25):
|
178 |
+
labels = input_ids.clone()
|
179 |
+
attn_mask = (input_ids != self.config.pad_token_id).long()
|
180 |
+
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(input_ids.device)
|
181 |
+
vocab_size = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["vocab_size"])
|
182 |
+
probability_matrix = torch.full(labels.shape, mlm_probability)
|
183 |
+
input_ids, labels = self.mask(input_ids, vocab_size, input_ids.device, targets=labels,
|
184 |
+
probability_matrix = probability_matrix)
|
185 |
+
mlm_output = self.transformer(input_ids,
|
186 |
+
attention_mask = attn_mask,
|
187 |
+
encoder_hidden_states = image_embeds,
|
188 |
+
encoder_attention_mask = image_atts,
|
189 |
+
return_dict = True,
|
190 |
+
labels = labels,
|
191 |
+
)
|
192 |
+
return mlm_output.loss
|
193 |
+
# mlm_output = self.transformer(input_ids,
|
194 |
+
# attention_mask = attn_mask,
|
195 |
+
# encoder_hidden_states = image_embeds,
|
196 |
+
# encoder_attention_mask = image_atts,
|
197 |
+
# return_dict = True,
|
198 |
+
# ).last_hidden_state
|
199 |
+
# logits = self.mlm_proj(mlm_output)
|
200 |
+
|
201 |
+
# # logits = logits[:, :-1, :].contiguous().view(-1, vocab_size)
|
202 |
+
# logits = logits[:, 1:, :].contiguous().view(-1, vocab_size)
|
203 |
+
# labels = labels[:, 1:].contiguous().view(-1)
|
204 |
+
|
205 |
+
# mlm_loss = F.cross_entropy(
|
206 |
+
# logits,
|
207 |
+
# labels,
|
208 |
+
# # label_smoothing=0.1,
|
209 |
+
# )
|
210 |
+
# return mlm_loss
|
211 |
+
|
212 |
+
|
213 |
+
def forward(self, x:TensorType) -> TensorType:
|
214 |
+
attn_mask = (x != self.config.pad_token_id).long()
|
215 |
+
out = self.transformer(input_ids=x, attention_mask=attn_mask)
|
216 |
+
pooled_out = self.pooler(out, attn_mask)
|
217 |
+
|
218 |
+
return self.proj(pooled_out)
|
219 |
+
|
220 |
+
def lock(self, unlocked_layers:int=0, freeze_layer_norm:bool=True):
|
221 |
+
if not unlocked_layers: # full freezing
|
222 |
+
for n, p in self.transformer.named_parameters():
|
223 |
+
p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
|
224 |
+
return
|
225 |
+
|
226 |
+
encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer
|
227 |
+
layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
|
228 |
+
print(f"Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model")
|
229 |
+
embeddings = getattr(
|
230 |
+
self.transformer, arch_dict[self.config.model_type]["config_names"]["token_embeddings_attr"])
|
231 |
+
modules = [embeddings, *layer_list][:-unlocked_layers]
|
232 |
+
# freeze layers
|
233 |
+
for module in modules:
|
234 |
+
for n, p in module.named_parameters():
|
235 |
+
p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
|
236 |
+
|
237 |
+
|
238 |
+
@torch.jit.ignore
|
239 |
+
def set_grad_checkpointing(self, enable=True):
|
240 |
+
self.transformer.gradient_checkpointing_enable()
|
241 |
+
|
242 |
+
def get_num_layers(self):
|
243 |
+
encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer
|
244 |
+
layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
|
245 |
+
return len(layer_list)
|
246 |
+
|
247 |
+
def init_parameters(self):
|
248 |
+
pass
|
eva_clip/loss.py
ADDED
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from torch.nn import functional as F
|
5 |
+
|
6 |
+
try:
|
7 |
+
import torch.distributed.nn
|
8 |
+
from torch import distributed as dist
|
9 |
+
has_distributed = True
|
10 |
+
except ImportError:
|
11 |
+
has_distributed = False
|
12 |
+
|
13 |
+
try:
|
14 |
+
import horovod.torch as hvd
|
15 |
+
except ImportError:
|
16 |
+
hvd = None
|
17 |
+
|
18 |
+
from timm.loss import LabelSmoothingCrossEntropy
|
19 |
+
|
20 |
+
|
21 |
+
def gather_features(
|
22 |
+
image_features,
|
23 |
+
text_features,
|
24 |
+
local_loss=False,
|
25 |
+
gather_with_grad=False,
|
26 |
+
rank=0,
|
27 |
+
world_size=1,
|
28 |
+
use_horovod=False
|
29 |
+
):
|
30 |
+
assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'
|
31 |
+
if use_horovod:
|
32 |
+
assert hvd is not None, 'Please install horovod'
|
33 |
+
if gather_with_grad:
|
34 |
+
all_image_features = hvd.allgather(image_features)
|
35 |
+
all_text_features = hvd.allgather(text_features)
|
36 |
+
else:
|
37 |
+
with torch.no_grad():
|
38 |
+
all_image_features = hvd.allgather(image_features)
|
39 |
+
all_text_features = hvd.allgather(text_features)
|
40 |
+
if not local_loss:
|
41 |
+
# ensure grads for local rank when all_* features don't have a gradient
|
42 |
+
gathered_image_features = list(all_image_features.chunk(world_size, dim=0))
|
43 |
+
gathered_text_features = list(all_text_features.chunk(world_size, dim=0))
|
44 |
+
gathered_image_features[rank] = image_features
|
45 |
+
gathered_text_features[rank] = text_features
|
46 |
+
all_image_features = torch.cat(gathered_image_features, dim=0)
|
47 |
+
all_text_features = torch.cat(gathered_text_features, dim=0)
|
48 |
+
else:
|
49 |
+
# We gather tensors from all gpus
|
50 |
+
if gather_with_grad:
|
51 |
+
all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
|
52 |
+
all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
|
53 |
+
# all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features, async_op=True), dim=0)
|
54 |
+
# all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features, async_op=True), dim=0)
|
55 |
+
else:
|
56 |
+
gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
|
57 |
+
gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
|
58 |
+
dist.all_gather(gathered_image_features, image_features)
|
59 |
+
dist.all_gather(gathered_text_features, text_features)
|
60 |
+
if not local_loss:
|
61 |
+
# ensure grads for local rank when all_* features don't have a gradient
|
62 |
+
gathered_image_features[rank] = image_features
|
63 |
+
gathered_text_features[rank] = text_features
|
64 |
+
all_image_features = torch.cat(gathered_image_features, dim=0)
|
65 |
+
all_text_features = torch.cat(gathered_text_features, dim=0)
|
66 |
+
|
67 |
+
return all_image_features, all_text_features
|
68 |
+
|
69 |
+
|
70 |
+
class ClipLoss(nn.Module):
|
71 |
+
|
72 |
+
def __init__(
|
73 |
+
self,
|
74 |
+
local_loss=False,
|
75 |
+
gather_with_grad=False,
|
76 |
+
cache_labels=False,
|
77 |
+
rank=0,
|
78 |
+
world_size=1,
|
79 |
+
use_horovod=False,
|
80 |
+
smoothing=0.,
|
81 |
+
):
|
82 |
+
super().__init__()
|
83 |
+
self.local_loss = local_loss
|
84 |
+
self.gather_with_grad = gather_with_grad
|
85 |
+
self.cache_labels = cache_labels
|
86 |
+
self.rank = rank
|
87 |
+
self.world_size = world_size
|
88 |
+
self.use_horovod = use_horovod
|
89 |
+
self.label_smoothing_cross_entropy = LabelSmoothingCrossEntropy(smoothing=smoothing) if smoothing > 0 else None
|
90 |
+
|
91 |
+
# cache state
|
92 |
+
self.prev_num_logits = 0
|
93 |
+
self.labels = {}
|
94 |
+
|
95 |
+
def forward(self, image_features, text_features, logit_scale=1.):
|
96 |
+
device = image_features.device
|
97 |
+
if self.world_size > 1:
|
98 |
+
all_image_features, all_text_features = gather_features(
|
99 |
+
image_features, text_features,
|
100 |
+
self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)
|
101 |
+
|
102 |
+
if self.local_loss:
|
103 |
+
logits_per_image = logit_scale * image_features @ all_text_features.T
|
104 |
+
logits_per_text = logit_scale * text_features @ all_image_features.T
|
105 |
+
else:
|
106 |
+
logits_per_image = logit_scale * all_image_features @ all_text_features.T
|
107 |
+
logits_per_text = logits_per_image.T
|
108 |
+
else:
|
109 |
+
logits_per_image = logit_scale * image_features @ text_features.T
|
110 |
+
logits_per_text = logit_scale * text_features @ image_features.T
|
111 |
+
# calculated ground-truth and cache if enabled
|
112 |
+
num_logits = logits_per_image.shape[0]
|
113 |
+
if self.prev_num_logits != num_logits or device not in self.labels:
|
114 |
+
labels = torch.arange(num_logits, device=device, dtype=torch.long)
|
115 |
+
if self.world_size > 1 and self.local_loss:
|
116 |
+
labels = labels + num_logits * self.rank
|
117 |
+
if self.cache_labels:
|
118 |
+
self.labels[device] = labels
|
119 |
+
self.prev_num_logits = num_logits
|
120 |
+
else:
|
121 |
+
labels = self.labels[device]
|
122 |
+
|
123 |
+
if self.label_smoothing_cross_entropy:
|
124 |
+
total_loss = (
|
125 |
+
self.label_smoothing_cross_entropy(logits_per_image, labels) +
|
126 |
+
self.label_smoothing_cross_entropy(logits_per_text, labels)
|
127 |
+
) / 2
|
128 |
+
else:
|
129 |
+
total_loss = (
|
130 |
+
F.cross_entropy(logits_per_image, labels) +
|
131 |
+
F.cross_entropy(logits_per_text, labels)
|
132 |
+
) / 2
|
133 |
+
|
134 |
+
acc = None
|
135 |
+
i2t_acc = (logits_per_image.argmax(-1) == labels).sum() / len(logits_per_image)
|
136 |
+
t2i_acc = (logits_per_text.argmax(-1) == labels).sum() / len(logits_per_text)
|
137 |
+
acc = {"i2t": i2t_acc, "t2i": t2i_acc}
|
138 |
+
return total_loss, acc
|
eva_clip/model.py
ADDED
@@ -0,0 +1,439 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" CLIP Model
|
2 |
+
|
3 |
+
Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
|
4 |
+
"""
|
5 |
+
import os
|
6 |
+
from dataclasses import dataclass
|
7 |
+
from typing import Optional, Tuple, Union
|
8 |
+
from functools import partial
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
import torch.nn.functional as F
|
13 |
+
from torch import nn
|
14 |
+
|
15 |
+
try:
|
16 |
+
from .hf_model import HFTextEncoder
|
17 |
+
except:
|
18 |
+
HFTextEncoder = None
|
19 |
+
from .modified_resnet import ModifiedResNet
|
20 |
+
from .timm_model import TimmModel
|
21 |
+
from .eva_vit_model import EVAVisionTransformer
|
22 |
+
from .transformer import LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer
|
23 |
+
|
24 |
+
try:
|
25 |
+
from apex.normalization import FusedLayerNorm
|
26 |
+
except:
|
27 |
+
FusedLayerNorm = LayerNorm
|
28 |
+
print("Please 'pip install apex'")
|
29 |
+
|
30 |
+
try:
|
31 |
+
import xformers.ops as xops
|
32 |
+
except ImportError:
|
33 |
+
xops = None
|
34 |
+
print("Please 'pip install xformers'")
|
35 |
+
|
36 |
+
@dataclass
|
37 |
+
class CLIPVisionCfg:
|
38 |
+
layers: Union[Tuple[int, int, int, int], int] = 12
|
39 |
+
width: int = 768
|
40 |
+
head_width: int = 64
|
41 |
+
mlp_ratio: float = 4.0
|
42 |
+
patch_size: int = 16
|
43 |
+
image_size: Union[Tuple[int, int], int] = 224
|
44 |
+
ls_init_value: Optional[float] = None # layer scale initial value
|
45 |
+
patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
|
46 |
+
global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580)
|
47 |
+
drop_path_rate: Optional[float] = None # drop path rate
|
48 |
+
timm_model_name: str = None # a valid model name overrides layers, width, patch_size
|
49 |
+
timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model
|
50 |
+
timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
|
51 |
+
timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '')
|
52 |
+
timm_proj_bias: bool = False # enable bias final projection
|
53 |
+
eva_model_name: str = None # a valid eva model name overrides layers, width, patch_size
|
54 |
+
qkv_bias: bool = True
|
55 |
+
fusedLN: bool = False
|
56 |
+
xattn: bool = False
|
57 |
+
postnorm: bool = False
|
58 |
+
rope: bool = False
|
59 |
+
pt_hw_seq_len: int = 16 # 224/14
|
60 |
+
intp_freq: bool = False
|
61 |
+
naiveswiglu: bool = False
|
62 |
+
subln: bool = False
|
63 |
+
|
64 |
+
|
65 |
+
@dataclass
|
66 |
+
class CLIPTextCfg:
|
67 |
+
context_length: int = 77
|
68 |
+
vocab_size: int = 49408
|
69 |
+
width: int = 512
|
70 |
+
heads: int = 8
|
71 |
+
layers: int = 12
|
72 |
+
ls_init_value: Optional[float] = None # layer scale initial value
|
73 |
+
hf_model_name: str = None
|
74 |
+
hf_tokenizer_name: str = None
|
75 |
+
hf_model_pretrained: bool = True
|
76 |
+
proj: str = 'mlp'
|
77 |
+
pooler_type: str = 'mean_pooler'
|
78 |
+
masked_language_modeling: bool = False
|
79 |
+
fusedLN: bool = False
|
80 |
+
xattn: bool = False
|
81 |
+
attn_mask: bool = True
|
82 |
+
|
83 |
+
def get_cast_dtype(precision: str):
|
84 |
+
cast_dtype = None
|
85 |
+
if precision == 'bf16':
|
86 |
+
cast_dtype = torch.bfloat16
|
87 |
+
elif precision == 'fp16':
|
88 |
+
cast_dtype = torch.float16
|
89 |
+
return cast_dtype
|
90 |
+
|
91 |
+
|
92 |
+
def _build_vision_tower(
|
93 |
+
embed_dim: int,
|
94 |
+
vision_cfg: CLIPVisionCfg,
|
95 |
+
quick_gelu: bool = False,
|
96 |
+
cast_dtype: Optional[torch.dtype] = None
|
97 |
+
):
|
98 |
+
if isinstance(vision_cfg, dict):
|
99 |
+
vision_cfg = CLIPVisionCfg(**vision_cfg)
|
100 |
+
|
101 |
+
# OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
|
102 |
+
# memory efficient in recent PyTorch releases (>= 1.10).
|
103 |
+
# NOTE: timm models always use native GELU regardless of quick_gelu flag.
|
104 |
+
act_layer = QuickGELU if quick_gelu else nn.GELU
|
105 |
+
|
106 |
+
if vision_cfg.eva_model_name:
|
107 |
+
vision_heads = vision_cfg.width // vision_cfg.head_width
|
108 |
+
norm_layer = LayerNorm
|
109 |
+
|
110 |
+
visual = EVAVisionTransformer(
|
111 |
+
img_size=vision_cfg.image_size,
|
112 |
+
patch_size=vision_cfg.patch_size,
|
113 |
+
num_classes=embed_dim,
|
114 |
+
use_mean_pooling=vision_cfg.global_average_pool, #False
|
115 |
+
init_values=vision_cfg.ls_init_value,
|
116 |
+
patch_dropout=vision_cfg.patch_dropout,
|
117 |
+
embed_dim=vision_cfg.width,
|
118 |
+
depth=vision_cfg.layers,
|
119 |
+
num_heads=vision_heads,
|
120 |
+
mlp_ratio=vision_cfg.mlp_ratio,
|
121 |
+
qkv_bias=vision_cfg.qkv_bias,
|
122 |
+
drop_path_rate=vision_cfg.drop_path_rate,
|
123 |
+
norm_layer= partial(FusedLayerNorm, eps=1e-6) if vision_cfg.fusedLN else partial(norm_layer, eps=1e-6),
|
124 |
+
xattn=vision_cfg.xattn,
|
125 |
+
rope=vision_cfg.rope,
|
126 |
+
postnorm=vision_cfg.postnorm,
|
127 |
+
pt_hw_seq_len= vision_cfg.pt_hw_seq_len, # 224/14
|
128 |
+
intp_freq= vision_cfg.intp_freq,
|
129 |
+
naiveswiglu= vision_cfg.naiveswiglu,
|
130 |
+
subln= vision_cfg.subln
|
131 |
+
)
|
132 |
+
elif vision_cfg.timm_model_name:
|
133 |
+
visual = TimmModel(
|
134 |
+
vision_cfg.timm_model_name,
|
135 |
+
pretrained=vision_cfg.timm_model_pretrained,
|
136 |
+
pool=vision_cfg.timm_pool,
|
137 |
+
proj=vision_cfg.timm_proj,
|
138 |
+
proj_bias=vision_cfg.timm_proj_bias,
|
139 |
+
embed_dim=embed_dim,
|
140 |
+
image_size=vision_cfg.image_size
|
141 |
+
)
|
142 |
+
act_layer = nn.GELU # so that text transformer doesn't use QuickGELU w/ timm models
|
143 |
+
elif isinstance(vision_cfg.layers, (tuple, list)):
|
144 |
+
vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
|
145 |
+
visual = ModifiedResNet(
|
146 |
+
layers=vision_cfg.layers,
|
147 |
+
output_dim=embed_dim,
|
148 |
+
heads=vision_heads,
|
149 |
+
image_size=vision_cfg.image_size,
|
150 |
+
width=vision_cfg.width
|
151 |
+
)
|
152 |
+
else:
|
153 |
+
vision_heads = vision_cfg.width // vision_cfg.head_width
|
154 |
+
norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
|
155 |
+
visual = VisionTransformer(
|
156 |
+
image_size=vision_cfg.image_size,
|
157 |
+
patch_size=vision_cfg.patch_size,
|
158 |
+
width=vision_cfg.width,
|
159 |
+
layers=vision_cfg.layers,
|
160 |
+
heads=vision_heads,
|
161 |
+
mlp_ratio=vision_cfg.mlp_ratio,
|
162 |
+
ls_init_value=vision_cfg.ls_init_value,
|
163 |
+
patch_dropout=vision_cfg.patch_dropout,
|
164 |
+
global_average_pool=vision_cfg.global_average_pool,
|
165 |
+
output_dim=embed_dim,
|
166 |
+
act_layer=act_layer,
|
167 |
+
norm_layer=norm_layer,
|
168 |
+
)
|
169 |
+
|
170 |
+
return visual
|
171 |
+
|
172 |
+
|
173 |
+
def _build_text_tower(
|
174 |
+
embed_dim: int,
|
175 |
+
text_cfg: CLIPTextCfg,
|
176 |
+
quick_gelu: bool = False,
|
177 |
+
cast_dtype: Optional[torch.dtype] = None,
|
178 |
+
):
|
179 |
+
if isinstance(text_cfg, dict):
|
180 |
+
text_cfg = CLIPTextCfg(**text_cfg)
|
181 |
+
|
182 |
+
if text_cfg.hf_model_name:
|
183 |
+
text = HFTextEncoder(
|
184 |
+
text_cfg.hf_model_name,
|
185 |
+
output_dim=embed_dim,
|
186 |
+
tokenizer_name=text_cfg.hf_tokenizer_name,
|
187 |
+
proj=text_cfg.proj,
|
188 |
+
pooler_type=text_cfg.pooler_type,
|
189 |
+
masked_language_modeling=text_cfg.masked_language_modeling
|
190 |
+
)
|
191 |
+
else:
|
192 |
+
act_layer = QuickGELU if quick_gelu else nn.GELU
|
193 |
+
norm_layer = LayerNorm
|
194 |
+
|
195 |
+
text = TextTransformer(
|
196 |
+
context_length=text_cfg.context_length,
|
197 |
+
vocab_size=text_cfg.vocab_size,
|
198 |
+
width=text_cfg.width,
|
199 |
+
heads=text_cfg.heads,
|
200 |
+
layers=text_cfg.layers,
|
201 |
+
ls_init_value=text_cfg.ls_init_value,
|
202 |
+
output_dim=embed_dim,
|
203 |
+
act_layer=act_layer,
|
204 |
+
norm_layer= FusedLayerNorm if text_cfg.fusedLN else norm_layer,
|
205 |
+
xattn=text_cfg.xattn,
|
206 |
+
attn_mask=text_cfg.attn_mask,
|
207 |
+
)
|
208 |
+
return text
|
209 |
+
|
210 |
+
class CLIP(nn.Module):
|
211 |
+
def __init__(
|
212 |
+
self,
|
213 |
+
embed_dim: int,
|
214 |
+
vision_cfg: CLIPVisionCfg,
|
215 |
+
text_cfg: CLIPTextCfg,
|
216 |
+
quick_gelu: bool = False,
|
217 |
+
cast_dtype: Optional[torch.dtype] = None,
|
218 |
+
):
|
219 |
+
super().__init__()
|
220 |
+
self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
|
221 |
+
|
222 |
+
text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
|
223 |
+
self.transformer = text.transformer
|
224 |
+
self.vocab_size = text.vocab_size
|
225 |
+
self.token_embedding = text.token_embedding
|
226 |
+
self.positional_embedding = text.positional_embedding
|
227 |
+
self.ln_final = text.ln_final
|
228 |
+
self.text_projection = text.text_projection
|
229 |
+
self.register_buffer('attn_mask', text.attn_mask, persistent=False)
|
230 |
+
|
231 |
+
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
232 |
+
|
233 |
+
def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
|
234 |
+
# lock image tower as per LiT - https://arxiv.org/abs/2111.07991
|
235 |
+
self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
|
236 |
+
|
237 |
+
@torch.jit.ignore
|
238 |
+
def set_grad_checkpointing(self, enable=True):
|
239 |
+
self.visual.set_grad_checkpointing(enable)
|
240 |
+
self.transformer.grad_checkpointing = enable
|
241 |
+
|
242 |
+
@torch.jit.ignore
|
243 |
+
def no_weight_decay(self):
|
244 |
+
return {'logit_scale'}
|
245 |
+
|
246 |
+
def encode_image(self, image, normalize: bool = False):
|
247 |
+
features = self.visual(image)
|
248 |
+
return F.normalize(features, dim=-1) if normalize else features
|
249 |
+
|
250 |
+
def encode_text(self, text, normalize: bool = False):
|
251 |
+
cast_dtype = self.transformer.get_cast_dtype()
|
252 |
+
|
253 |
+
x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
|
254 |
+
|
255 |
+
x = x + self.positional_embedding.to(cast_dtype)
|
256 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
257 |
+
x = self.transformer(x, attn_mask=self.attn_mask)
|
258 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
259 |
+
x = self.ln_final(x) # [batch_size, n_ctx, transformer.width]
|
260 |
+
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
261 |
+
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
|
262 |
+
return F.normalize(x, dim=-1) if normalize else x
|
263 |
+
|
264 |
+
def forward(self, image, text):
|
265 |
+
image_features = self.encode_image(image, normalize=True)
|
266 |
+
text_features = self.encode_text(text, normalize=True)
|
267 |
+
return image_features, text_features, self.logit_scale.exp()
|
268 |
+
|
269 |
+
|
270 |
+
class CustomCLIP(nn.Module):
|
271 |
+
def __init__(
|
272 |
+
self,
|
273 |
+
embed_dim: int,
|
274 |
+
vision_cfg: CLIPVisionCfg,
|
275 |
+
text_cfg: CLIPTextCfg,
|
276 |
+
quick_gelu: bool = False,
|
277 |
+
cast_dtype: Optional[torch.dtype] = None,
|
278 |
+
itm_task: bool = False,
|
279 |
+
):
|
280 |
+
super().__init__()
|
281 |
+
self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
|
282 |
+
self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
|
283 |
+
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
284 |
+
|
285 |
+
def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
|
286 |
+
# lock image tower as per LiT - https://arxiv.org/abs/2111.07991
|
287 |
+
self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
|
288 |
+
|
289 |
+
def lock_text_tower(self, unlocked_layers:int=0, freeze_layer_norm:bool=True):
|
290 |
+
self.text.lock(unlocked_layers, freeze_layer_norm)
|
291 |
+
|
292 |
+
@torch.jit.ignore
|
293 |
+
def set_grad_checkpointing(self, enable=True):
|
294 |
+
self.visual.set_grad_checkpointing(enable)
|
295 |
+
self.text.set_grad_checkpointing(enable)
|
296 |
+
|
297 |
+
@torch.jit.ignore
|
298 |
+
def no_weight_decay(self):
|
299 |
+
return {'logit_scale'}
|
300 |
+
|
301 |
+
def encode_image(self, image, normalize: bool = False):
|
302 |
+
features = self.visual(image)
|
303 |
+
return F.normalize(features, dim=-1) if normalize else features
|
304 |
+
|
305 |
+
def encode_text(self, text, normalize: bool = False):
|
306 |
+
features = self.text(text)
|
307 |
+
return F.normalize(features, dim=-1) if normalize else features
|
308 |
+
|
309 |
+
def forward(self, image, text):
|
310 |
+
image_features = self.encode_image(image, normalize=True)
|
311 |
+
text_features = self.encode_text(text, normalize=True)
|
312 |
+
return image_features, text_features, self.logit_scale.exp()
|
313 |
+
|
314 |
+
|
315 |
+
def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
|
316 |
+
"""Convert applicable model parameters to low-precision (bf16 or fp16)"""
|
317 |
+
|
318 |
+
def _convert_weights(l):
|
319 |
+
|
320 |
+
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
|
321 |
+
l.weight.data = l.weight.data.to(dtype)
|
322 |
+
if l.bias is not None:
|
323 |
+
l.bias.data = l.bias.data.to(dtype)
|
324 |
+
|
325 |
+
if isinstance(l, (nn.MultiheadAttention, Attention)):
|
326 |
+
for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
|
327 |
+
tensor = getattr(l, attr, None)
|
328 |
+
if tensor is not None:
|
329 |
+
tensor.data = tensor.data.to(dtype)
|
330 |
+
|
331 |
+
if isinstance(l, nn.Parameter):
|
332 |
+
l.data = l.data.to(dtype)
|
333 |
+
|
334 |
+
for name in ["text_projection", "proj"]:
|
335 |
+
if hasattr(l, name) and isinstance(l, nn.Parameter):
|
336 |
+
attr = getattr(l, name, None)
|
337 |
+
if attr is not None:
|
338 |
+
attr.data = attr.data.to(dtype)
|
339 |
+
|
340 |
+
model.apply(_convert_weights)
|
341 |
+
|
342 |
+
|
343 |
+
convert_weights_to_fp16 = convert_weights_to_lp # backwards compat
|
344 |
+
|
345 |
+
|
346 |
+
# used to maintain checkpoint compatibility
|
347 |
+
def convert_to_custom_text_state_dict(state_dict: dict):
|
348 |
+
if 'text_projection' in state_dict:
|
349 |
+
# old format state_dict, move text tower -> .text
|
350 |
+
new_state_dict = {}
|
351 |
+
for k, v in state_dict.items():
|
352 |
+
if any(k.startswith(p) for p in (
|
353 |
+
'text_projection',
|
354 |
+
'positional_embedding',
|
355 |
+
'token_embedding',
|
356 |
+
'transformer',
|
357 |
+
'ln_final',
|
358 |
+
'logit_scale'
|
359 |
+
)):
|
360 |
+
k = 'text.' + k
|
361 |
+
new_state_dict[k] = v
|
362 |
+
return new_state_dict
|
363 |
+
return state_dict
|
364 |
+
|
365 |
+
|
366 |
+
def build_model_from_openai_state_dict(
|
367 |
+
state_dict: dict,
|
368 |
+
quick_gelu=True,
|
369 |
+
cast_dtype=torch.float16,
|
370 |
+
):
|
371 |
+
vit = "visual.proj" in state_dict
|
372 |
+
|
373 |
+
if vit:
|
374 |
+
vision_width = state_dict["visual.conv1.weight"].shape[0]
|
375 |
+
vision_layers = len(
|
376 |
+
[k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
|
377 |
+
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
|
378 |
+
grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
|
379 |
+
image_size = vision_patch_size * grid_size
|
380 |
+
else:
|
381 |
+
counts: list = [
|
382 |
+
len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
|
383 |
+
vision_layers = tuple(counts)
|
384 |
+
vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
|
385 |
+
output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
|
386 |
+
vision_patch_size = None
|
387 |
+
assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
|
388 |
+
image_size = output_width * 32
|
389 |
+
|
390 |
+
embed_dim = state_dict["text_projection"].shape[1]
|
391 |
+
context_length = state_dict["positional_embedding"].shape[0]
|
392 |
+
vocab_size = state_dict["token_embedding.weight"].shape[0]
|
393 |
+
transformer_width = state_dict["ln_final.weight"].shape[0]
|
394 |
+
transformer_heads = transformer_width // 64
|
395 |
+
transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
|
396 |
+
|
397 |
+
vision_cfg = CLIPVisionCfg(
|
398 |
+
layers=vision_layers,
|
399 |
+
width=vision_width,
|
400 |
+
patch_size=vision_patch_size,
|
401 |
+
image_size=image_size,
|
402 |
+
)
|
403 |
+
text_cfg = CLIPTextCfg(
|
404 |
+
context_length=context_length,
|
405 |
+
vocab_size=vocab_size,
|
406 |
+
width=transformer_width,
|
407 |
+
heads=transformer_heads,
|
408 |
+
layers=transformer_layers
|
409 |
+
)
|
410 |
+
model = CLIP(
|
411 |
+
embed_dim,
|
412 |
+
vision_cfg=vision_cfg,
|
413 |
+
text_cfg=text_cfg,
|
414 |
+
quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU
|
415 |
+
cast_dtype=cast_dtype,
|
416 |
+
)
|
417 |
+
|
418 |
+
for key in ["input_resolution", "context_length", "vocab_size"]:
|
419 |
+
state_dict.pop(key, None)
|
420 |
+
|
421 |
+
convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16
|
422 |
+
model.load_state_dict(state_dict)
|
423 |
+
return model.eval()
|
424 |
+
|
425 |
+
|
426 |
+
def trace_model(model, batch_size=256, device=torch.device('cpu')):
|
427 |
+
model.eval()
|
428 |
+
image_size = model.visual.image_size
|
429 |
+
example_images = torch.ones((batch_size, 3, image_size, image_size), device=device)
|
430 |
+
example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device)
|
431 |
+
model = torch.jit.trace_module(
|
432 |
+
model,
|
433 |
+
inputs=dict(
|
434 |
+
forward=(example_images, example_text),
|
435 |
+
encode_text=(example_text,),
|
436 |
+
encode_image=(example_images,)
|
437 |
+
))
|
438 |
+
model.visual.image_size = image_size
|
439 |
+
return model
|
eva_clip/model_configs/EVA01-CLIP-B-16.json
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 512,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"layers": 12,
|
6 |
+
"width": 768,
|
7 |
+
"patch_size": 16,
|
8 |
+
"eva_model_name": "eva-clip-b-16",
|
9 |
+
"ls_init_value": 0.1,
|
10 |
+
"drop_path_rate": 0.0
|
11 |
+
},
|
12 |
+
"text_cfg": {
|
13 |
+
"context_length": 77,
|
14 |
+
"vocab_size": 49408,
|
15 |
+
"width": 512,
|
16 |
+
"heads": 8,
|
17 |
+
"layers": 12
|
18 |
+
}
|
19 |
+
}
|
eva_clip/model_configs/EVA01-CLIP-g-14-plus.json
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 1024,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"layers": 40,
|
6 |
+
"width": 1408,
|
7 |
+
"head_width": 88,
|
8 |
+
"mlp_ratio": 4.3637,
|
9 |
+
"patch_size": 14,
|
10 |
+
"eva_model_name": "eva-clip-g-14-x",
|
11 |
+
"drop_path_rate": 0,
|
12 |
+
"xattn": true,
|
13 |
+
"fusedLN": true
|
14 |
+
},
|
15 |
+
"text_cfg": {
|
16 |
+
"context_length": 77,
|
17 |
+
"vocab_size": 49408,
|
18 |
+
"width": 1024,
|
19 |
+
"heads": 16,
|
20 |
+
"layers": 24,
|
21 |
+
"xattn": false,
|
22 |
+
"fusedLN": true
|
23 |
+
}
|
24 |
+
}
|
eva_clip/model_configs/EVA01-CLIP-g-14.json
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 1024,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"layers": 40,
|
6 |
+
"width": 1408,
|
7 |
+
"head_width": 88,
|
8 |
+
"mlp_ratio": 4.3637,
|
9 |
+
"patch_size": 14,
|
10 |
+
"eva_model_name": "eva-clip-g-14-x",
|
11 |
+
"drop_path_rate": 0.4,
|
12 |
+
"xattn": true,
|
13 |
+
"fusedLN": true
|
14 |
+
},
|
15 |
+
"text_cfg": {
|
16 |
+
"context_length": 77,
|
17 |
+
"vocab_size": 49408,
|
18 |
+
"width": 768,
|
19 |
+
"heads": 12,
|
20 |
+
"layers": 12,
|
21 |
+
"xattn": false,
|
22 |
+
"fusedLN": true
|
23 |
+
}
|
24 |
+
}
|
eva_clip/model_configs/EVA02-CLIP-B-16.json
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 512,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"layers": 12,
|
6 |
+
"width": 768,
|
7 |
+
"head_width": 64,
|
8 |
+
"patch_size": 16,
|
9 |
+
"mlp_ratio": 2.6667,
|
10 |
+
"eva_model_name": "eva-clip-b-16-X",
|
11 |
+
"drop_path_rate": 0.0,
|
12 |
+
"xattn": true,
|
13 |
+
"fusedLN": true,
|
14 |
+
"rope": true,
|
15 |
+
"pt_hw_seq_len": 16,
|
16 |
+
"intp_freq": true,
|
17 |
+
"naiveswiglu": true,
|
18 |
+
"subln": true
|
19 |
+
},
|
20 |
+
"text_cfg": {
|
21 |
+
"context_length": 77,
|
22 |
+
"vocab_size": 49408,
|
23 |
+
"width": 512,
|
24 |
+
"heads": 8,
|
25 |
+
"layers": 12,
|
26 |
+
"xattn": true,
|
27 |
+
"fusedLN": true
|
28 |
+
}
|
29 |
+
}
|
eva_clip/model_configs/EVA02-CLIP-L-14-336.json
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 768,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 336,
|
5 |
+
"layers": 24,
|
6 |
+
"width": 1024,
|
7 |
+
"drop_path_rate": 0,
|
8 |
+
"head_width": 64,
|
9 |
+
"mlp_ratio": 2.6667,
|
10 |
+
"patch_size": 14,
|
11 |
+
"eva_model_name": "eva-clip-l-14-336",
|
12 |
+
"xattn": true,
|
13 |
+
"fusedLN": true,
|
14 |
+
"rope": true,
|
15 |
+
"pt_hw_seq_len": 16,
|
16 |
+
"intp_freq": true,
|
17 |
+
"naiveswiglu": true,
|
18 |
+
"subln": true
|
19 |
+
},
|
20 |
+
"text_cfg": {
|
21 |
+
"context_length": 77,
|
22 |
+
"vocab_size": 49408,
|
23 |
+
"width": 768,
|
24 |
+
"heads": 12,
|
25 |
+
"layers": 12,
|
26 |
+
"xattn": false,
|
27 |
+
"fusedLN": true
|
28 |
+
}
|
29 |
+
}
|
eva_clip/model_configs/EVA02-CLIP-L-14.json
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 768,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"layers": 24,
|
6 |
+
"width": 1024,
|
7 |
+
"drop_path_rate": 0,
|
8 |
+
"head_width": 64,
|
9 |
+
"mlp_ratio": 2.6667,
|
10 |
+
"patch_size": 14,
|
11 |
+
"eva_model_name": "eva-clip-l-14",
|
12 |
+
"xattn": true,
|
13 |
+
"fusedLN": true,
|
14 |
+
"rope": true,
|
15 |
+
"pt_hw_seq_len": 16,
|
16 |
+
"intp_freq": true,
|
17 |
+
"naiveswiglu": true,
|
18 |
+
"subln": true
|
19 |
+
},
|
20 |
+
"text_cfg": {
|
21 |
+
"context_length": 77,
|
22 |
+
"vocab_size": 49408,
|
23 |
+
"width": 768,
|
24 |
+
"heads": 12,
|
25 |
+
"layers": 12,
|
26 |
+
"xattn": false,
|
27 |
+
"fusedLN": true
|
28 |
+
}
|
29 |
+
}
|
eva_clip/model_configs/EVA02-CLIP-bigE-14-plus.json
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 1024,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"layers": 64,
|
6 |
+
"width": 1792,
|
7 |
+
"head_width": 112,
|
8 |
+
"mlp_ratio": 8.571428571428571,
|
9 |
+
"patch_size": 14,
|
10 |
+
"eva_model_name": "eva-clip-4b-14-x",
|
11 |
+
"drop_path_rate": 0,
|
12 |
+
"xattn": true,
|
13 |
+
"postnorm": true,
|
14 |
+
"fusedLN": true
|
15 |
+
},
|
16 |
+
"text_cfg": {
|
17 |
+
"context_length": 77,
|
18 |
+
"vocab_size": 49408,
|
19 |
+
"width": 1280,
|
20 |
+
"heads": 20,
|
21 |
+
"layers": 32,
|
22 |
+
"xattn": false,
|
23 |
+
"fusedLN": true
|
24 |
+
}
|
25 |
+
}
|
eva_clip/model_configs/EVA02-CLIP-bigE-14.json
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"embed_dim": 1024,
|
3 |
+
"vision_cfg": {
|
4 |
+
"image_size": 224,
|
5 |
+
"layers": 64,
|
6 |
+
"width": 1792,
|
7 |
+
"head_width": 112,
|
8 |
+
"mlp_ratio": 8.571428571428571,
|
9 |
+
"patch_size": 14,
|
10 |
+
"eva_model_name": "eva-clip-4b-14-x",
|
11 |
+
"drop_path_rate": 0,
|
12 |
+
"xattn": true,
|
13 |
+
"postnorm": true,
|
14 |
+
"fusedLN": true
|
15 |
+
},
|
16 |
+
"text_cfg": {
|
17 |
+
"context_length": 77,
|
18 |
+
"vocab_size": 49408,
|
19 |
+
"width": 1024,
|
20 |
+
"heads": 16,
|
21 |
+
"layers": 24,
|
22 |
+
"xattn": false,
|
23 |
+
"fusedLN": true
|
24 |
+
}
|
25 |
+
}
|
eva_clip/modified_resnet.py
ADDED
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import OrderedDict
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
from torch.nn import functional as F
|
6 |
+
|
7 |
+
from eva_clip.utils import freeze_batch_norm_2d
|
8 |
+
|
9 |
+
|
10 |
+
class Bottleneck(nn.Module):
|
11 |
+
expansion = 4
|
12 |
+
|
13 |
+
def __init__(self, inplanes, planes, stride=1):
|
14 |
+
super().__init__()
|
15 |
+
|
16 |
+
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
|
17 |
+
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
|
18 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
19 |
+
self.act1 = nn.ReLU(inplace=True)
|
20 |
+
|
21 |
+
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
|
22 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
23 |
+
self.act2 = nn.ReLU(inplace=True)
|
24 |
+
|
25 |
+
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
|
26 |
+
|
27 |
+
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
|
28 |
+
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
29 |
+
self.act3 = nn.ReLU(inplace=True)
|
30 |
+
|
31 |
+
self.downsample = None
|
32 |
+
self.stride = stride
|
33 |
+
|
34 |
+
if stride > 1 or inplanes != planes * Bottleneck.expansion:
|
35 |
+
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
|
36 |
+
self.downsample = nn.Sequential(OrderedDict([
|
37 |
+
("-1", nn.AvgPool2d(stride)),
|
38 |
+
("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
|
39 |
+
("1", nn.BatchNorm2d(planes * self.expansion))
|
40 |
+
]))
|
41 |
+
|
42 |
+
def forward(self, x: torch.Tensor):
|
43 |
+
identity = x
|
44 |
+
|
45 |
+
out = self.act1(self.bn1(self.conv1(x)))
|
46 |
+
out = self.act2(self.bn2(self.conv2(out)))
|
47 |
+
out = self.avgpool(out)
|
48 |
+
out = self.bn3(self.conv3(out))
|
49 |
+
|
50 |
+
if self.downsample is not None:
|
51 |
+
identity = self.downsample(x)
|
52 |
+
|
53 |
+
out += identity
|
54 |
+
out = self.act3(out)
|
55 |
+
return out
|
56 |
+
|
57 |
+
|
58 |
+
class AttentionPool2d(nn.Module):
|
59 |
+
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
|
60 |
+
super().__init__()
|
61 |
+
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
|
62 |
+
self.k_proj = nn.Linear(embed_dim, embed_dim)
|
63 |
+
self.q_proj = nn.Linear(embed_dim, embed_dim)
|
64 |
+
self.v_proj = nn.Linear(embed_dim, embed_dim)
|
65 |
+
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
|
66 |
+
self.num_heads = num_heads
|
67 |
+
|
68 |
+
def forward(self, x):
|
69 |
+
x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
|
70 |
+
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
|
71 |
+
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
|
72 |
+
x, _ = F.multi_head_attention_forward(
|
73 |
+
query=x, key=x, value=x,
|
74 |
+
embed_dim_to_check=x.shape[-1],
|
75 |
+
num_heads=self.num_heads,
|
76 |
+
q_proj_weight=self.q_proj.weight,
|
77 |
+
k_proj_weight=self.k_proj.weight,
|
78 |
+
v_proj_weight=self.v_proj.weight,
|
79 |
+
in_proj_weight=None,
|
80 |
+
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
|
81 |
+
bias_k=None,
|
82 |
+
bias_v=None,
|
83 |
+
add_zero_attn=False,
|
84 |
+
dropout_p=0.,
|
85 |
+
out_proj_weight=self.c_proj.weight,
|
86 |
+
out_proj_bias=self.c_proj.bias,
|
87 |
+
use_separate_proj_weight=True,
|
88 |
+
training=self.training,
|
89 |
+
need_weights=False
|
90 |
+
)
|
91 |
+
|
92 |
+
return x[0]
|
93 |
+
|
94 |
+
|
95 |
+
class ModifiedResNet(nn.Module):
|
96 |
+
"""
|
97 |
+
A ResNet class that is similar to torchvision's but contains the following changes:
|
98 |
+
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
|
99 |
+
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
|
100 |
+
- The final pooling layer is a QKV attention instead of an average pool
|
101 |
+
"""
|
102 |
+
|
103 |
+
def __init__(self, layers, output_dim, heads, image_size=224, width=64):
|
104 |
+
super().__init__()
|
105 |
+
self.output_dim = output_dim
|
106 |
+
self.image_size = image_size
|
107 |
+
|
108 |
+
# the 3-layer stem
|
109 |
+
self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
|
110 |
+
self.bn1 = nn.BatchNorm2d(width // 2)
|
111 |
+
self.act1 = nn.ReLU(inplace=True)
|
112 |
+
self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
|
113 |
+
self.bn2 = nn.BatchNorm2d(width // 2)
|
114 |
+
self.act2 = nn.ReLU(inplace=True)
|
115 |
+
self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
|
116 |
+
self.bn3 = nn.BatchNorm2d(width)
|
117 |
+
self.act3 = nn.ReLU(inplace=True)
|
118 |
+
self.avgpool = nn.AvgPool2d(2)
|
119 |
+
|
120 |
+
# residual layers
|
121 |
+
self._inplanes = width # this is a *mutable* variable used during construction
|
122 |
+
self.layer1 = self._make_layer(width, layers[0])
|
123 |
+
self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
|
124 |
+
self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
|
125 |
+
self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
|
126 |
+
|
127 |
+
embed_dim = width * 32 # the ResNet feature dimension
|
128 |
+
self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)
|
129 |
+
|
130 |
+
self.init_parameters()
|
131 |
+
|
132 |
+
def _make_layer(self, planes, blocks, stride=1):
|
133 |
+
layers = [Bottleneck(self._inplanes, planes, stride)]
|
134 |
+
|
135 |
+
self._inplanes = planes * Bottleneck.expansion
|
136 |
+
for _ in range(1, blocks):
|
137 |
+
layers.append(Bottleneck(self._inplanes, planes))
|
138 |
+
|
139 |
+
return nn.Sequential(*layers)
|
140 |
+
|
141 |
+
def init_parameters(self):
|
142 |
+
if self.attnpool is not None:
|
143 |
+
std = self.attnpool.c_proj.in_features ** -0.5
|
144 |
+
nn.init.normal_(self.attnpool.q_proj.weight, std=std)
|
145 |
+
nn.init.normal_(self.attnpool.k_proj.weight, std=std)
|
146 |
+
nn.init.normal_(self.attnpool.v_proj.weight, std=std)
|
147 |
+
nn.init.normal_(self.attnpool.c_proj.weight, std=std)
|
148 |
+
|
149 |
+
for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
|
150 |
+
for name, param in resnet_block.named_parameters():
|
151 |
+
if name.endswith("bn3.weight"):
|
152 |
+
nn.init.zeros_(param)
|
153 |
+
|
154 |
+
def lock(self, unlocked_groups=0, freeze_bn_stats=False):
|
155 |
+
assert unlocked_groups == 0, 'partial locking not currently supported for this model'
|
156 |
+
for param in self.parameters():
|
157 |
+
param.requires_grad = False
|
158 |
+
if freeze_bn_stats:
|
159 |
+
freeze_batch_norm_2d(self)
|
160 |
+
|
161 |
+
@torch.jit.ignore
|
162 |
+
def set_grad_checkpointing(self, enable=True):
|
163 |
+
# FIXME support for non-transformer
|
164 |
+
pass
|
165 |
+
|
166 |
+
def stem(self, x):
|
167 |
+
x = self.act1(self.bn1(self.conv1(x)))
|
168 |
+
x = self.act2(self.bn2(self.conv2(x)))
|
169 |
+
x = self.act3(self.bn3(self.conv3(x)))
|
170 |
+
x = self.avgpool(x)
|
171 |
+
return x
|
172 |
+
|
173 |
+
def forward(self, x):
|
174 |
+
x = self.stem(x)
|
175 |
+
x = self.layer1(x)
|
176 |
+
x = self.layer2(x)
|
177 |
+
x = self.layer3(x)
|
178 |
+
x = self.layer4(x)
|
179 |
+
x = self.attnpool(x)
|
180 |
+
|
181 |
+
return x
|
eva_clip/openai.py
ADDED
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" OpenAI pretrained model functions
|
2 |
+
|
3 |
+
Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
|
4 |
+
"""
|
5 |
+
|
6 |
+
import os
|
7 |
+
import warnings
|
8 |
+
from typing import List, Optional, Union
|
9 |
+
|
10 |
+
import torch
|
11 |
+
|
12 |
+
from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype
|
13 |
+
from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url
|
14 |
+
|
15 |
+
__all__ = ["list_openai_models", "load_openai_model"]
|
16 |
+
|
17 |
+
|
18 |
+
def list_openai_models() -> List[str]:
|
19 |
+
"""Returns the names of available CLIP models"""
|
20 |
+
return list_pretrained_models_by_tag('openai')
|
21 |
+
|
22 |
+
|
23 |
+
def load_openai_model(
|
24 |
+
name: str,
|
25 |
+
precision: Optional[str] = None,
|
26 |
+
device: Optional[Union[str, torch.device]] = None,
|
27 |
+
jit: bool = True,
|
28 |
+
cache_dir: Optional[str] = None,
|
29 |
+
):
|
30 |
+
"""Load a CLIP model
|
31 |
+
|
32 |
+
Parameters
|
33 |
+
----------
|
34 |
+
name : str
|
35 |
+
A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
|
36 |
+
precision: str
|
37 |
+
Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'.
|
38 |
+
device : Union[str, torch.device]
|
39 |
+
The device to put the loaded model
|
40 |
+
jit : bool
|
41 |
+
Whether to load the optimized JIT model (default) or more hackable non-JIT model.
|
42 |
+
cache_dir : Optional[str]
|
43 |
+
The directory to cache the downloaded model weights
|
44 |
+
|
45 |
+
Returns
|
46 |
+
-------
|
47 |
+
model : torch.nn.Module
|
48 |
+
The CLIP model
|
49 |
+
preprocess : Callable[[PIL.Image], torch.Tensor]
|
50 |
+
A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
|
51 |
+
"""
|
52 |
+
if device is None:
|
53 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
54 |
+
if precision is None:
|
55 |
+
precision = 'fp32' if device == 'cpu' else 'fp16'
|
56 |
+
|
57 |
+
if get_pretrained_url(name, 'openai'):
|
58 |
+
model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir)
|
59 |
+
elif os.path.isfile(name):
|
60 |
+
model_path = name
|
61 |
+
else:
|
62 |
+
raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}")
|
63 |
+
|
64 |
+
try:
|
65 |
+
# loading JIT archive
|
66 |
+
model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
|
67 |
+
state_dict = None
|
68 |
+
except RuntimeError:
|
69 |
+
# loading saved state dict
|
70 |
+
if jit:
|
71 |
+
warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
|
72 |
+
jit = False
|
73 |
+
state_dict = torch.load(model_path, map_location="cpu")
|
74 |
+
|
75 |
+
if not jit:
|
76 |
+
# Build a non-jit model from the OpenAI jitted model state dict
|
77 |
+
cast_dtype = get_cast_dtype(precision)
|
78 |
+
try:
|
79 |
+
model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype)
|
80 |
+
except KeyError:
|
81 |
+
sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
|
82 |
+
model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype)
|
83 |
+
|
84 |
+
# model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use
|
85 |
+
model = model.to(device)
|
86 |
+
if precision.startswith('amp') or precision == 'fp32':
|
87 |
+
model.float()
|
88 |
+
elif precision == 'bf16':
|
89 |
+
convert_weights_to_lp(model, dtype=torch.bfloat16)
|
90 |
+
|
91 |
+
return model
|
92 |
+
|
93 |
+
# patch the device names
|
94 |
+
device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
|
95 |
+
device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
|
96 |
+
|
97 |
+
def patch_device(module):
|
98 |
+
try:
|
99 |
+
graphs = [module.graph] if hasattr(module, "graph") else []
|
100 |
+
except RuntimeError:
|
101 |
+
graphs = []
|
102 |
+
|
103 |
+
if hasattr(module, "forward1"):
|
104 |
+
graphs.append(module.forward1.graph)
|
105 |
+
|
106 |
+
for graph in graphs:
|
107 |
+
for node in graph.findAllNodes("prim::Constant"):
|
108 |
+
if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
|
109 |
+
node.copyAttributes(device_node)
|
110 |
+
|
111 |
+
model.apply(patch_device)
|
112 |
+
patch_device(model.encode_image)
|
113 |
+
patch_device(model.encode_text)
|
114 |
+
|
115 |
+
# patch dtype to float32 (typically for CPU)
|
116 |
+
if precision == 'fp32':
|
117 |
+
float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
|
118 |
+
float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
|
119 |
+
float_node = float_input.node()
|
120 |
+
|
121 |
+
def patch_float(module):
|
122 |
+
try:
|
123 |
+
graphs = [module.graph] if hasattr(module, "graph") else []
|
124 |
+
except RuntimeError:
|
125 |
+
graphs = []
|
126 |
+
|
127 |
+
if hasattr(module, "forward1"):
|
128 |
+
graphs.append(module.forward1.graph)
|
129 |
+
|
130 |
+
for graph in graphs:
|
131 |
+
for node in graph.findAllNodes("aten::to"):
|
132 |
+
inputs = list(node.inputs())
|
133 |
+
for i in [1, 2]: # dtype can be the second or third argument to aten::to()
|
134 |
+
if inputs[i].node()["value"] == 5:
|
135 |
+
inputs[i].node().copyAttributes(float_node)
|
136 |
+
|
137 |
+
model.apply(patch_float)
|
138 |
+
patch_float(model.encode_image)
|
139 |
+
patch_float(model.encode_text)
|
140 |
+
model.float()
|
141 |
+
|
142 |
+
# ensure image_size attr available at consistent location for both jit and non-jit
|
143 |
+
model.visual.image_size = model.input_resolution.item()
|
144 |
+
return model
|
eva_clip/pretrained.py
ADDED
@@ -0,0 +1,332 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import hashlib
|
2 |
+
import os
|
3 |
+
import urllib
|
4 |
+
import warnings
|
5 |
+
from functools import partial
|
6 |
+
from typing import Dict, Union
|
7 |
+
|
8 |
+
from tqdm import tqdm
|
9 |
+
|
10 |
+
try:
|
11 |
+
from huggingface_hub import hf_hub_download
|
12 |
+
_has_hf_hub = True
|
13 |
+
except ImportError:
|
14 |
+
hf_hub_download = None
|
15 |
+
_has_hf_hub = False
|
16 |
+
|
17 |
+
|
18 |
+
def _pcfg(url='', hf_hub='', filename='', mean=None, std=None):
|
19 |
+
return dict(
|
20 |
+
url=url,
|
21 |
+
hf_hub=hf_hub,
|
22 |
+
mean=mean,
|
23 |
+
std=std,
|
24 |
+
)
|
25 |
+
|
26 |
+
_VITB32 = dict(
|
27 |
+
openai=_pcfg(
|
28 |
+
"https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
|
29 |
+
laion400m_e31=_pcfg(
|
30 |
+
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
|
31 |
+
laion400m_e32=_pcfg(
|
32 |
+
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
|
33 |
+
laion2b_e16=_pcfg(
|
34 |
+
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth"),
|
35 |
+
laion2b_s34b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/')
|
36 |
+
)
|
37 |
+
|
38 |
+
_VITB32_quickgelu = dict(
|
39 |
+
openai=_pcfg(
|
40 |
+
"https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
|
41 |
+
laion400m_e31=_pcfg(
|
42 |
+
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
|
43 |
+
laion400m_e32=_pcfg(
|
44 |
+
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
|
45 |
+
)
|
46 |
+
|
47 |
+
_VITB16 = dict(
|
48 |
+
openai=_pcfg(
|
49 |
+
"https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"),
|
50 |
+
laion400m_e31=_pcfg(
|
51 |
+
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt"),
|
52 |
+
laion400m_e32=_pcfg(
|
53 |
+
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt"),
|
54 |
+
laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-laion2B-s34B-b88K/'),
|
55 |
+
)
|
56 |
+
|
57 |
+
_EVAB16 = dict(
|
58 |
+
eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_B_psz14to16.pt'),
|
59 |
+
eva02=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_B_psz14to16.pt'),
|
60 |
+
eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt'),
|
61 |
+
eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt'),
|
62 |
+
)
|
63 |
+
|
64 |
+
_VITB16_PLUS_240 = dict(
|
65 |
+
laion400m_e31=_pcfg(
|
66 |
+
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt"),
|
67 |
+
laion400m_e32=_pcfg(
|
68 |
+
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"),
|
69 |
+
)
|
70 |
+
|
71 |
+
_VITL14 = dict(
|
72 |
+
openai=_pcfg(
|
73 |
+
"https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"),
|
74 |
+
laion400m_e31=_pcfg(
|
75 |
+
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt"),
|
76 |
+
laion400m_e32=_pcfg(
|
77 |
+
"https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt"),
|
78 |
+
laion2b_s32b_b82k=_pcfg(
|
79 |
+
hf_hub='laion/CLIP-ViT-L-14-laion2B-s32B-b82K/',
|
80 |
+
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
|
81 |
+
)
|
82 |
+
|
83 |
+
_EVAL14 = dict(
|
84 |
+
eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_L_psz14.pt'),
|
85 |
+
eva02=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_L_psz14.pt'),
|
86 |
+
eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt'),
|
87 |
+
eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt'),
|
88 |
+
)
|
89 |
+
|
90 |
+
_VITL14_336 = dict(
|
91 |
+
openai=_pcfg(
|
92 |
+
"https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"),
|
93 |
+
)
|
94 |
+
|
95 |
+
_EVAL14_336 = dict(
|
96 |
+
eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt'),
|
97 |
+
eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt'),
|
98 |
+
eva_clip_224to336=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_224to336.pt'),
|
99 |
+
eva02_clip_224to336=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_224to336.pt'),
|
100 |
+
)
|
101 |
+
|
102 |
+
_VITH14 = dict(
|
103 |
+
laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'),
|
104 |
+
)
|
105 |
+
|
106 |
+
_VITg14 = dict(
|
107 |
+
laion2b_s12b_b42k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s12B-b42K/'),
|
108 |
+
laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s34B-b88K/'),
|
109 |
+
)
|
110 |
+
|
111 |
+
_EVAg14 = dict(
|
112 |
+
eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/'),
|
113 |
+
eva01=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_g_psz14.pt'),
|
114 |
+
eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt'),
|
115 |
+
eva01_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt'),
|
116 |
+
)
|
117 |
+
|
118 |
+
_EVAg14_PLUS = dict(
|
119 |
+
eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/'),
|
120 |
+
eva01=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_g_psz14.pt'),
|
121 |
+
eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt'),
|
122 |
+
eva01_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt'),
|
123 |
+
)
|
124 |
+
|
125 |
+
_VITbigG14 = dict(
|
126 |
+
laion2b_s39b_b160k=_pcfg(hf_hub='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/'),
|
127 |
+
)
|
128 |
+
|
129 |
+
_EVAbigE14 = dict(
|
130 |
+
eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_E_psz14.pt'),
|
131 |
+
eva02=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_E_psz14.pt'),
|
132 |
+
eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt'),
|
133 |
+
eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt'),
|
134 |
+
)
|
135 |
+
|
136 |
+
_EVAbigE14_PLUS = dict(
|
137 |
+
eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_E_psz14.pt'),
|
138 |
+
eva02=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_E_psz14.pt'),
|
139 |
+
eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt'),
|
140 |
+
eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt'),
|
141 |
+
)
|
142 |
+
|
143 |
+
|
144 |
+
_PRETRAINED = {
|
145 |
+
# "ViT-B-32": _VITB32,
|
146 |
+
"OpenaiCLIP-B-32": _VITB32,
|
147 |
+
"OpenCLIP-B-32": _VITB32,
|
148 |
+
|
149 |
+
# "ViT-B-32-quickgelu": _VITB32_quickgelu,
|
150 |
+
"OpenaiCLIP-B-32-quickgelu": _VITB32_quickgelu,
|
151 |
+
"OpenCLIP-B-32-quickgelu": _VITB32_quickgelu,
|
152 |
+
|
153 |
+
# "ViT-B-16": _VITB16,
|
154 |
+
"OpenaiCLIP-B-16": _VITB16,
|
155 |
+
"OpenCLIP-B-16": _VITB16,
|
156 |
+
|
157 |
+
"EVA02-B-16": _EVAB16,
|
158 |
+
"EVA02-CLIP-B-16": _EVAB16,
|
159 |
+
|
160 |
+
# "ViT-B-16-plus-240": _VITB16_PLUS_240,
|
161 |
+
"OpenCLIP-B-16-plus-240": _VITB16_PLUS_240,
|
162 |
+
|
163 |
+
# "ViT-L-14": _VITL14,
|
164 |
+
"OpenaiCLIP-L-14": _VITL14,
|
165 |
+
"OpenCLIP-L-14": _VITL14,
|
166 |
+
|
167 |
+
"EVA02-L-14": _EVAL14,
|
168 |
+
"EVA02-CLIP-L-14": _EVAL14,
|
169 |
+
|
170 |
+
# "ViT-L-14-336": _VITL14_336,
|
171 |
+
"OpenaiCLIP-L-14-336": _VITL14_336,
|
172 |
+
|
173 |
+
"EVA02-CLIP-L-14-336": _EVAL14_336,
|
174 |
+
|
175 |
+
# "ViT-H-14": _VITH14,
|
176 |
+
# "ViT-g-14": _VITg14,
|
177 |
+
"OpenCLIP-H-14": _VITH14,
|
178 |
+
"OpenCLIP-g-14": _VITg14,
|
179 |
+
|
180 |
+
"EVA01-CLIP-g-14": _EVAg14,
|
181 |
+
"EVA01-CLIP-g-14-plus": _EVAg14_PLUS,
|
182 |
+
|
183 |
+
# "ViT-bigG-14": _VITbigG14,
|
184 |
+
"OpenCLIP-bigG-14": _VITbigG14,
|
185 |
+
|
186 |
+
"EVA02-CLIP-bigE-14": _EVAbigE14,
|
187 |
+
"EVA02-CLIP-bigE-14-plus": _EVAbigE14_PLUS,
|
188 |
+
}
|
189 |
+
|
190 |
+
|
191 |
+
def _clean_tag(tag: str):
|
192 |
+
# normalize pretrained tags
|
193 |
+
return tag.lower().replace('-', '_')
|
194 |
+
|
195 |
+
|
196 |
+
def list_pretrained(as_str: bool = False):
|
197 |
+
""" returns list of pretrained models
|
198 |
+
Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True
|
199 |
+
"""
|
200 |
+
return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()]
|
201 |
+
|
202 |
+
|
203 |
+
def list_pretrained_models_by_tag(tag: str):
|
204 |
+
""" return all models having the specified pretrain tag """
|
205 |
+
models = []
|
206 |
+
tag = _clean_tag(tag)
|
207 |
+
for k in _PRETRAINED.keys():
|
208 |
+
if tag in _PRETRAINED[k]:
|
209 |
+
models.append(k)
|
210 |
+
return models
|
211 |
+
|
212 |
+
|
213 |
+
def list_pretrained_tags_by_model(model: str):
|
214 |
+
""" return all pretrain tags for the specified model architecture """
|
215 |
+
tags = []
|
216 |
+
if model in _PRETRAINED:
|
217 |
+
tags.extend(_PRETRAINED[model].keys())
|
218 |
+
return tags
|
219 |
+
|
220 |
+
|
221 |
+
def is_pretrained_cfg(model: str, tag: str):
|
222 |
+
if model not in _PRETRAINED:
|
223 |
+
return False
|
224 |
+
return _clean_tag(tag) in _PRETRAINED[model]
|
225 |
+
|
226 |
+
|
227 |
+
def get_pretrained_cfg(model: str, tag: str):
|
228 |
+
if model not in _PRETRAINED:
|
229 |
+
return {}
|
230 |
+
model_pretrained = _PRETRAINED[model]
|
231 |
+
return model_pretrained.get(_clean_tag(tag), {})
|
232 |
+
|
233 |
+
|
234 |
+
def get_pretrained_url(model: str, tag: str):
|
235 |
+
cfg = get_pretrained_cfg(model, _clean_tag(tag))
|
236 |
+
return cfg.get('url', '')
|
237 |
+
|
238 |
+
|
239 |
+
def download_pretrained_from_url(
|
240 |
+
url: str,
|
241 |
+
cache_dir: Union[str, None] = None,
|
242 |
+
):
|
243 |
+
if not cache_dir:
|
244 |
+
cache_dir = os.path.expanduser("~/.cache/clip")
|
245 |
+
os.makedirs(cache_dir, exist_ok=True)
|
246 |
+
filename = os.path.basename(url)
|
247 |
+
|
248 |
+
if 'openaipublic' in url:
|
249 |
+
expected_sha256 = url.split("/")[-2]
|
250 |
+
elif 'mlfoundations' in url:
|
251 |
+
expected_sha256 = os.path.splitext(filename)[0].split("-")[-1]
|
252 |
+
else:
|
253 |
+
expected_sha256 = ''
|
254 |
+
|
255 |
+
download_target = os.path.join(cache_dir, filename)
|
256 |
+
|
257 |
+
if os.path.exists(download_target) and not os.path.isfile(download_target):
|
258 |
+
raise RuntimeError(f"{download_target} exists and is not a regular file")
|
259 |
+
|
260 |
+
if os.path.isfile(download_target):
|
261 |
+
if expected_sha256:
|
262 |
+
if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
|
263 |
+
return download_target
|
264 |
+
else:
|
265 |
+
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
|
266 |
+
else:
|
267 |
+
return download_target
|
268 |
+
|
269 |
+
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
|
270 |
+
with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
|
271 |
+
while True:
|
272 |
+
buffer = source.read(8192)
|
273 |
+
if not buffer:
|
274 |
+
break
|
275 |
+
|
276 |
+
output.write(buffer)
|
277 |
+
loop.update(len(buffer))
|
278 |
+
|
279 |
+
if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
|
280 |
+
raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match")
|
281 |
+
|
282 |
+
return download_target
|
283 |
+
|
284 |
+
|
285 |
+
def has_hf_hub(necessary=False):
|
286 |
+
if not _has_hf_hub and necessary:
|
287 |
+
# if no HF Hub module installed, and it is necessary to continue, raise error
|
288 |
+
raise RuntimeError(
|
289 |
+
'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
|
290 |
+
return _has_hf_hub
|
291 |
+
|
292 |
+
|
293 |
+
def download_pretrained_from_hf(
|
294 |
+
model_id: str,
|
295 |
+
filename: str = 'open_clip_pytorch_model.bin',
|
296 |
+
revision=None,
|
297 |
+
cache_dir: Union[str, None] = None,
|
298 |
+
):
|
299 |
+
has_hf_hub(True)
|
300 |
+
cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir)
|
301 |
+
return cached_file
|
302 |
+
|
303 |
+
|
304 |
+
def download_pretrained(
|
305 |
+
cfg: Dict,
|
306 |
+
force_hf_hub: bool = False,
|
307 |
+
cache_dir: Union[str, None] = None,
|
308 |
+
):
|
309 |
+
target = ''
|
310 |
+
if not cfg:
|
311 |
+
return target
|
312 |
+
|
313 |
+
download_url = cfg.get('url', '')
|
314 |
+
download_hf_hub = cfg.get('hf_hub', '')
|
315 |
+
if download_hf_hub and force_hf_hub:
|
316 |
+
# use HF hub even if url exists
|
317 |
+
download_url = ''
|
318 |
+
|
319 |
+
if download_url:
|
320 |
+
target = download_pretrained_from_url(download_url, cache_dir=cache_dir)
|
321 |
+
elif download_hf_hub:
|
322 |
+
has_hf_hub(True)
|
323 |
+
# we assume the hf_hub entries in pretrained config combine model_id + filename in
|
324 |
+
# 'org/model_name/filename.pt' form. To specify just the model id w/o filename and
|
325 |
+
# use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'.
|
326 |
+
model_id, filename = os.path.split(download_hf_hub)
|
327 |
+
if filename:
|
328 |
+
target = download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir)
|
329 |
+
else:
|
330 |
+
target = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
|
331 |
+
|
332 |
+
return target
|
eva_clip/rope.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from math import pi
|
2 |
+
import torch
|
3 |
+
from torch import nn
|
4 |
+
from einops import rearrange, repeat
|
5 |
+
import logging
|
6 |
+
|
7 |
+
def broadcat(tensors, dim = -1):
|
8 |
+
num_tensors = len(tensors)
|
9 |
+
shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
|
10 |
+
assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions'
|
11 |
+
shape_len = list(shape_lens)[0]
|
12 |
+
dim = (dim + shape_len) if dim < 0 else dim
|
13 |
+
dims = list(zip(*map(lambda t: list(t.shape), tensors)))
|
14 |
+
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
|
15 |
+
assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), 'invalid dimensions for broadcastable concatentation'
|
16 |
+
max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
|
17 |
+
expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
|
18 |
+
expanded_dims.insert(dim, (dim, dims[dim]))
|
19 |
+
expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
|
20 |
+
tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
|
21 |
+
return torch.cat(tensors, dim = dim)
|
22 |
+
|
23 |
+
def rotate_half(x):
|
24 |
+
x = rearrange(x, '... (d r) -> ... d r', r = 2)
|
25 |
+
x1, x2 = x.unbind(dim = -1)
|
26 |
+
x = torch.stack((-x2, x1), dim = -1)
|
27 |
+
return rearrange(x, '... d r -> ... (d r)')
|
28 |
+
|
29 |
+
|
30 |
+
class VisionRotaryEmbedding(nn.Module):
|
31 |
+
def __init__(
|
32 |
+
self,
|
33 |
+
dim,
|
34 |
+
pt_seq_len,
|
35 |
+
ft_seq_len=None,
|
36 |
+
custom_freqs = None,
|
37 |
+
freqs_for = 'lang',
|
38 |
+
theta = 10000,
|
39 |
+
max_freq = 10,
|
40 |
+
num_freqs = 1,
|
41 |
+
):
|
42 |
+
super().__init__()
|
43 |
+
if custom_freqs:
|
44 |
+
freqs = custom_freqs
|
45 |
+
elif freqs_for == 'lang':
|
46 |
+
freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
|
47 |
+
elif freqs_for == 'pixel':
|
48 |
+
freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
|
49 |
+
elif freqs_for == 'constant':
|
50 |
+
freqs = torch.ones(num_freqs).float()
|
51 |
+
else:
|
52 |
+
raise ValueError(f'unknown modality {freqs_for}')
|
53 |
+
|
54 |
+
if ft_seq_len is None: ft_seq_len = pt_seq_len
|
55 |
+
t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
|
56 |
+
|
57 |
+
freqs_h = torch.einsum('..., f -> ... f', t, freqs)
|
58 |
+
freqs_h = repeat(freqs_h, '... n -> ... (n r)', r = 2)
|
59 |
+
|
60 |
+
freqs_w = torch.einsum('..., f -> ... f', t, freqs)
|
61 |
+
freqs_w = repeat(freqs_w, '... n -> ... (n r)', r = 2)
|
62 |
+
|
63 |
+
freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim = -1)
|
64 |
+
|
65 |
+
self.register_buffer("freqs_cos", freqs.cos())
|
66 |
+
self.register_buffer("freqs_sin", freqs.sin())
|
67 |
+
|
68 |
+
logging.info(f'Shape of rope freq: {self.freqs_cos.shape}')
|
69 |
+
|
70 |
+
def forward(self, t, start_index = 0):
|
71 |
+
rot_dim = self.freqs_cos.shape[-1]
|
72 |
+
end_index = start_index + rot_dim
|
73 |
+
assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
|
74 |
+
t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
|
75 |
+
t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin)
|
76 |
+
|
77 |
+
return torch.cat((t_left, t, t_right), dim = -1)
|
78 |
+
|
79 |
+
class VisionRotaryEmbeddingFast(nn.Module):
|
80 |
+
def __init__(
|
81 |
+
self,
|
82 |
+
dim,
|
83 |
+
pt_seq_len,
|
84 |
+
ft_seq_len=None,
|
85 |
+
custom_freqs = None,
|
86 |
+
freqs_for = 'lang',
|
87 |
+
theta = 10000,
|
88 |
+
max_freq = 10,
|
89 |
+
num_freqs = 1,
|
90 |
+
patch_dropout = 0.
|
91 |
+
):
|
92 |
+
super().__init__()
|
93 |
+
if custom_freqs:
|
94 |
+
freqs = custom_freqs
|
95 |
+
elif freqs_for == 'lang':
|
96 |
+
freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
|
97 |
+
elif freqs_for == 'pixel':
|
98 |
+
freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
|
99 |
+
elif freqs_for == 'constant':
|
100 |
+
freqs = torch.ones(num_freqs).float()
|
101 |
+
else:
|
102 |
+
raise ValueError(f'unknown modality {freqs_for}')
|
103 |
+
|
104 |
+
if ft_seq_len is None: ft_seq_len = pt_seq_len
|
105 |
+
t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
|
106 |
+
|
107 |
+
freqs = torch.einsum('..., f -> ... f', t, freqs)
|
108 |
+
freqs = repeat(freqs, '... n -> ... (n r)', r = 2)
|
109 |
+
freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim = -1)
|
110 |
+
|
111 |
+
freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
|
112 |
+
freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
|
113 |
+
|
114 |
+
self.patch_dropout = patch_dropout
|
115 |
+
|
116 |
+
self.register_buffer("freqs_cos", freqs_cos)
|
117 |
+
self.register_buffer("freqs_sin", freqs_sin)
|
118 |
+
|
119 |
+
logging.info(f'Shape of rope freq: {self.freqs_cos.shape}')
|
120 |
+
|
121 |
+
def forward(self, t, patch_indices_keep=None):
|
122 |
+
if patch_indices_keep is not None:
|
123 |
+
batch = t.size()[0]
|
124 |
+
batch_indices = torch.arange(batch)
|
125 |
+
batch_indices = batch_indices[..., None]
|
126 |
+
|
127 |
+
freqs_cos = repeat(self.freqs_cos, 'i j -> n i m j', n=t.shape[0], m=t.shape[1])
|
128 |
+
freqs_sin = repeat(self.freqs_sin, 'i j -> n i m j', n=t.shape[0], m=t.shape[1])
|
129 |
+
|
130 |
+
freqs_cos = freqs_cos[batch_indices, patch_indices_keep]
|
131 |
+
freqs_cos = rearrange(freqs_cos, 'n i m j -> n m i j')
|
132 |
+
freqs_sin = freqs_sin[batch_indices, patch_indices_keep]
|
133 |
+
freqs_sin = rearrange(freqs_sin, 'n i m j -> n m i j')
|
134 |
+
|
135 |
+
return t * freqs_cos + rotate_half(t) * freqs_sin
|
136 |
+
|
137 |
+
return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
|
eva_clip/timm_model.py
ADDED
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" timm model adapter
|
2 |
+
|
3 |
+
Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model.
|
4 |
+
"""
|
5 |
+
import logging
|
6 |
+
from collections import OrderedDict
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
|
11 |
+
try:
|
12 |
+
import timm
|
13 |
+
from timm.models.layers import Mlp, to_2tuple
|
14 |
+
try:
|
15 |
+
# old timm imports < 0.8.1
|
16 |
+
from timm.models.layers.attention_pool2d import RotAttentionPool2d
|
17 |
+
from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d
|
18 |
+
except ImportError:
|
19 |
+
# new timm imports >= 0.8.1
|
20 |
+
from timm.layers import RotAttentionPool2d
|
21 |
+
from timm.layers import AttentionPool2d as AbsAttentionPool2d
|
22 |
+
except ImportError:
|
23 |
+
timm = None
|
24 |
+
|
25 |
+
from .utils import freeze_batch_norm_2d
|
26 |
+
|
27 |
+
|
28 |
+
class TimmModel(nn.Module):
|
29 |
+
""" timm model adapter
|
30 |
+
# FIXME this adapter is a work in progress, may change in ways that break weight compat
|
31 |
+
"""
|
32 |
+
|
33 |
+
def __init__(
|
34 |
+
self,
|
35 |
+
model_name,
|
36 |
+
embed_dim,
|
37 |
+
image_size=224,
|
38 |
+
pool='avg',
|
39 |
+
proj='linear',
|
40 |
+
proj_bias=False,
|
41 |
+
drop=0.,
|
42 |
+
pretrained=False):
|
43 |
+
super().__init__()
|
44 |
+
if timm is None:
|
45 |
+
raise RuntimeError("Please `pip install timm` to use timm models.")
|
46 |
+
|
47 |
+
self.image_size = to_2tuple(image_size)
|
48 |
+
self.trunk = timm.create_model(model_name, pretrained=pretrained)
|
49 |
+
feat_size = self.trunk.default_cfg.get('pool_size', None)
|
50 |
+
feature_ndim = 1 if not feat_size else 2
|
51 |
+
if pool in ('abs_attn', 'rot_attn'):
|
52 |
+
assert feature_ndim == 2
|
53 |
+
# if attn pooling used, remove both classifier and default pool
|
54 |
+
self.trunk.reset_classifier(0, global_pool='')
|
55 |
+
else:
|
56 |
+
# reset global pool if pool config set, otherwise leave as network default
|
57 |
+
reset_kwargs = dict(global_pool=pool) if pool else {}
|
58 |
+
self.trunk.reset_classifier(0, **reset_kwargs)
|
59 |
+
prev_chs = self.trunk.num_features
|
60 |
+
|
61 |
+
head_layers = OrderedDict()
|
62 |
+
if pool == 'abs_attn':
|
63 |
+
head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim)
|
64 |
+
prev_chs = embed_dim
|
65 |
+
elif pool == 'rot_attn':
|
66 |
+
head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
|
67 |
+
prev_chs = embed_dim
|
68 |
+
else:
|
69 |
+
assert proj, 'projection layer needed if non-attention pooling is used.'
|
70 |
+
|
71 |
+
# NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used
|
72 |
+
if proj == 'linear':
|
73 |
+
head_layers['drop'] = nn.Dropout(drop)
|
74 |
+
head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias)
|
75 |
+
elif proj == 'mlp':
|
76 |
+
head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop, bias=(True, proj_bias))
|
77 |
+
|
78 |
+
self.head = nn.Sequential(head_layers)
|
79 |
+
|
80 |
+
def lock(self, unlocked_groups=0, freeze_bn_stats=False):
|
81 |
+
""" lock modules
|
82 |
+
Args:
|
83 |
+
unlocked_groups (int): leave last n layer groups unlocked (default: 0)
|
84 |
+
"""
|
85 |
+
if not unlocked_groups:
|
86 |
+
# lock full model
|
87 |
+
for param in self.trunk.parameters():
|
88 |
+
param.requires_grad = False
|
89 |
+
if freeze_bn_stats:
|
90 |
+
freeze_batch_norm_2d(self.trunk)
|
91 |
+
else:
|
92 |
+
# NOTE: partial freeze requires latest timm (master) branch and is subject to change
|
93 |
+
try:
|
94 |
+
# FIXME import here until API stable and in an official release
|
95 |
+
from timm.models.helpers import group_parameters, group_modules
|
96 |
+
except ImportError:
|
97 |
+
raise RuntimeError(
|
98 |
+
'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`')
|
99 |
+
matcher = self.trunk.group_matcher()
|
100 |
+
gparams = group_parameters(self.trunk, matcher)
|
101 |
+
max_layer_id = max(gparams.keys())
|
102 |
+
max_layer_id = max_layer_id - unlocked_groups
|
103 |
+
for group_idx in range(max_layer_id + 1):
|
104 |
+
group = gparams[group_idx]
|
105 |
+
for param in group:
|
106 |
+
self.trunk.get_parameter(param).requires_grad = False
|
107 |
+
if freeze_bn_stats:
|
108 |
+
gmodules = group_modules(self.trunk, matcher, reverse=True)
|
109 |
+
gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
|
110 |
+
freeze_batch_norm_2d(self.trunk, gmodules)
|
111 |
+
|
112 |
+
@torch.jit.ignore
|
113 |
+
def set_grad_checkpointing(self, enable=True):
|
114 |
+
try:
|
115 |
+
self.trunk.set_grad_checkpointing(enable)
|
116 |
+
except Exception as e:
|
117 |
+
logging.warning('grad checkpointing not supported for this timm image tower, continuing without...')
|
118 |
+
|
119 |
+
def forward(self, x):
|
120 |
+
x = self.trunk(x)
|
121 |
+
x = self.head(x)
|
122 |
+
return x
|
eva_clip/tokenizer.py
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" CLIP tokenizer
|
2 |
+
|
3 |
+
Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
|
4 |
+
"""
|
5 |
+
import gzip
|
6 |
+
import html
|
7 |
+
import os
|
8 |
+
from functools import lru_cache
|
9 |
+
from typing import Union, List
|
10 |
+
|
11 |
+
import ftfy
|
12 |
+
import regex as re
|
13 |
+
import torch
|
14 |
+
|
15 |
+
# https://stackoverflow.com/q/62691279
|
16 |
+
import os
|
17 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
18 |
+
|
19 |
+
|
20 |
+
@lru_cache()
|
21 |
+
def default_bpe():
|
22 |
+
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
|
23 |
+
|
24 |
+
|
25 |
+
@lru_cache()
|
26 |
+
def bytes_to_unicode():
|
27 |
+
"""
|
28 |
+
Returns list of utf-8 byte and a corresponding list of unicode strings.
|
29 |
+
The reversible bpe codes work on unicode strings.
|
30 |
+
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
|
31 |
+
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
|
32 |
+
This is a signficant percentage of your normal, say, 32K bpe vocab.
|
33 |
+
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
|
34 |
+
And avoids mapping to whitespace/control characters the bpe code barfs on.
|
35 |
+
"""
|
36 |
+
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
|
37 |
+
cs = bs[:]
|
38 |
+
n = 0
|
39 |
+
for b in range(2**8):
|
40 |
+
if b not in bs:
|
41 |
+
bs.append(b)
|
42 |
+
cs.append(2**8+n)
|
43 |
+
n += 1
|
44 |
+
cs = [chr(n) for n in cs]
|
45 |
+
return dict(zip(bs, cs))
|
46 |
+
|
47 |
+
|
48 |
+
def get_pairs(word):
|
49 |
+
"""Return set of symbol pairs in a word.
|
50 |
+
Word is represented as tuple of symbols (symbols being variable-length strings).
|
51 |
+
"""
|
52 |
+
pairs = set()
|
53 |
+
prev_char = word[0]
|
54 |
+
for char in word[1:]:
|
55 |
+
pairs.add((prev_char, char))
|
56 |
+
prev_char = char
|
57 |
+
return pairs
|
58 |
+
|
59 |
+
|
60 |
+
def basic_clean(text):
|
61 |
+
text = ftfy.fix_text(text)
|
62 |
+
text = html.unescape(html.unescape(text))
|
63 |
+
return text.strip()
|
64 |
+
|
65 |
+
|
66 |
+
def whitespace_clean(text):
|
67 |
+
text = re.sub(r'\s+', ' ', text)
|
68 |
+
text = text.strip()
|
69 |
+
return text
|
70 |
+
|
71 |
+
|
72 |
+
class SimpleTokenizer(object):
|
73 |
+
def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
|
74 |
+
self.byte_encoder = bytes_to_unicode()
|
75 |
+
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
|
76 |
+
merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
|
77 |
+
merges = merges[1:49152-256-2+1]
|
78 |
+
merges = [tuple(merge.split()) for merge in merges]
|
79 |
+
vocab = list(bytes_to_unicode().values())
|
80 |
+
vocab = vocab + [v+'</w>' for v in vocab]
|
81 |
+
for merge in merges:
|
82 |
+
vocab.append(''.join(merge))
|
83 |
+
if not special_tokens:
|
84 |
+
special_tokens = ['<start_of_text>', '<end_of_text>']
|
85 |
+
else:
|
86 |
+
special_tokens = ['<start_of_text>', '<end_of_text>'] + special_tokens
|
87 |
+
vocab.extend(special_tokens)
|
88 |
+
self.encoder = dict(zip(vocab, range(len(vocab))))
|
89 |
+
self.decoder = {v: k for k, v in self.encoder.items()}
|
90 |
+
self.bpe_ranks = dict(zip(merges, range(len(merges))))
|
91 |
+
self.cache = {t:t for t in special_tokens}
|
92 |
+
special = "|".join(special_tokens)
|
93 |
+
self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
|
94 |
+
|
95 |
+
self.vocab_size = len(self.encoder)
|
96 |
+
self.all_special_ids = [self.encoder[t] for t in special_tokens]
|
97 |
+
|
98 |
+
def bpe(self, token):
|
99 |
+
if token in self.cache:
|
100 |
+
return self.cache[token]
|
101 |
+
word = tuple(token[:-1]) + ( token[-1] + '</w>',)
|
102 |
+
pairs = get_pairs(word)
|
103 |
+
|
104 |
+
if not pairs:
|
105 |
+
return token+'</w>'
|
106 |
+
|
107 |
+
while True:
|
108 |
+
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
|
109 |
+
if bigram not in self.bpe_ranks:
|
110 |
+
break
|
111 |
+
first, second = bigram
|
112 |
+
new_word = []
|
113 |
+
i = 0
|
114 |
+
while i < len(word):
|
115 |
+
try:
|
116 |
+
j = word.index(first, i)
|
117 |
+
new_word.extend(word[i:j])
|
118 |
+
i = j
|
119 |
+
except:
|
120 |
+
new_word.extend(word[i:])
|
121 |
+
break
|
122 |
+
|
123 |
+
if word[i] == first and i < len(word)-1 and word[i+1] == second:
|
124 |
+
new_word.append(first+second)
|
125 |
+
i += 2
|
126 |
+
else:
|
127 |
+
new_word.append(word[i])
|
128 |
+
i += 1
|
129 |
+
new_word = tuple(new_word)
|
130 |
+
word = new_word
|
131 |
+
if len(word) == 1:
|
132 |
+
break
|
133 |
+
else:
|
134 |
+
pairs = get_pairs(word)
|
135 |
+
word = ' '.join(word)
|
136 |
+
self.cache[token] = word
|
137 |
+
return word
|
138 |
+
|
139 |
+
def encode(self, text):
|
140 |
+
bpe_tokens = []
|
141 |
+
text = whitespace_clean(basic_clean(text)).lower()
|
142 |
+
for token in re.findall(self.pat, text):
|
143 |
+
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
|
144 |
+
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
|
145 |
+
return bpe_tokens
|
146 |
+
|
147 |
+
def decode(self, tokens):
|
148 |
+
text = ''.join([self.decoder[token] for token in tokens])
|
149 |
+
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
|
150 |
+
return text
|
151 |
+
|
152 |
+
|
153 |
+
_tokenizer = SimpleTokenizer()
|
154 |
+
|
155 |
+
|
156 |
+
def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
|
157 |
+
"""
|
158 |
+
Returns the tokenized representation of given input string(s)
|
159 |
+
|
160 |
+
Parameters
|
161 |
+
----------
|
162 |
+
texts : Union[str, List[str]]
|
163 |
+
An input string or a list of input strings to tokenize
|
164 |
+
context_length : int
|
165 |
+
The context length to use; all CLIP models use 77 as the context length
|
166 |
+
|
167 |
+
Returns
|
168 |
+
-------
|
169 |
+
A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
|
170 |
+
"""
|
171 |
+
if isinstance(texts, str):
|
172 |
+
texts = [texts]
|
173 |
+
|
174 |
+
sot_token = _tokenizer.encoder["<start_of_text>"]
|
175 |
+
eot_token = _tokenizer.encoder["<end_of_text>"]
|
176 |
+
all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
|
177 |
+
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
|
178 |
+
|
179 |
+
for i, tokens in enumerate(all_tokens):
|
180 |
+
if len(tokens) > context_length:
|
181 |
+
tokens = tokens[:context_length] # Truncate
|
182 |
+
tokens[-1] = eot_token
|
183 |
+
result[i, :len(tokens)] = torch.tensor(tokens)
|
184 |
+
|
185 |
+
return result
|
186 |
+
|
187 |
+
|
188 |
+
class HFTokenizer:
|
189 |
+
"HuggingFace tokenizer wrapper"
|
190 |
+
def __init__(self, tokenizer_name:str):
|
191 |
+
from transformers import AutoTokenizer
|
192 |
+
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
|
193 |
+
|
194 |
+
def __call__(self, texts:Union[str, List[str]], context_length:int=77) -> torch.Tensor:
|
195 |
+
# same cleaning as for default tokenizer, except lowercasing
|
196 |
+
# adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance
|
197 |
+
if isinstance(texts, str):
|
198 |
+
texts = [texts]
|
199 |
+
texts = [whitespace_clean(basic_clean(text)) for text in texts]
|
200 |
+
input_ids = self.tokenizer(texts, return_tensors='pt', max_length=context_length, padding='max_length', truncation=True).input_ids
|
201 |
+
return input_ids
|
eva_clip/transform.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Sequence, Tuple
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torchvision.transforms.functional as F
|
6 |
+
|
7 |
+
from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
|
8 |
+
CenterCrop
|
9 |
+
|
10 |
+
from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
|
11 |
+
|
12 |
+
|
13 |
+
class ResizeMaxSize(nn.Module):
|
14 |
+
|
15 |
+
def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0):
|
16 |
+
super().__init__()
|
17 |
+
if not isinstance(max_size, int):
|
18 |
+
raise TypeError(f"Size should be int. Got {type(max_size)}")
|
19 |
+
self.max_size = max_size
|
20 |
+
self.interpolation = interpolation
|
21 |
+
self.fn = min if fn == 'min' else min
|
22 |
+
self.fill = fill
|
23 |
+
|
24 |
+
def forward(self, img):
|
25 |
+
if isinstance(img, torch.Tensor):
|
26 |
+
height, width = img.shape[:2]
|
27 |
+
else:
|
28 |
+
width, height = img.size
|
29 |
+
scale = self.max_size / float(max(height, width))
|
30 |
+
if scale != 1.0:
|
31 |
+
new_size = tuple(round(dim * scale) for dim in (height, width))
|
32 |
+
img = F.resize(img, new_size, self.interpolation)
|
33 |
+
pad_h = self.max_size - new_size[0]
|
34 |
+
pad_w = self.max_size - new_size[1]
|
35 |
+
img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill)
|
36 |
+
return img
|
37 |
+
|
38 |
+
|
39 |
+
def _convert_to_rgb(image):
|
40 |
+
return image.convert('RGB')
|
41 |
+
|
42 |
+
|
43 |
+
# class CatGen(nn.Module):
|
44 |
+
# def __init__(self, num=4):
|
45 |
+
# self.num = num
|
46 |
+
# def mixgen_batch(image, text):
|
47 |
+
# batch_size = image.shape[0]
|
48 |
+
# index = np.random.permutation(batch_size)
|
49 |
+
|
50 |
+
# cat_images = []
|
51 |
+
# for i in range(batch_size):
|
52 |
+
# # image mixup
|
53 |
+
# image[i,:] = lam * image[i,:] + (1 - lam) * image[index[i],:]
|
54 |
+
# # text concat
|
55 |
+
# text[i] = tokenizer((str(text[i]) + " " + str(text[index[i]])))[0]
|
56 |
+
# text = torch.stack(text)
|
57 |
+
# return image, text
|
58 |
+
|
59 |
+
|
60 |
+
def image_transform(
|
61 |
+
image_size: int,
|
62 |
+
is_train: bool,
|
63 |
+
mean: Optional[Tuple[float, ...]] = None,
|
64 |
+
std: Optional[Tuple[float, ...]] = None,
|
65 |
+
resize_longest_max: bool = False,
|
66 |
+
fill_color: int = 0,
|
67 |
+
):
|
68 |
+
mean = mean or OPENAI_DATASET_MEAN
|
69 |
+
if not isinstance(mean, (list, tuple)):
|
70 |
+
mean = (mean,) * 3
|
71 |
+
|
72 |
+
std = std or OPENAI_DATASET_STD
|
73 |
+
if not isinstance(std, (list, tuple)):
|
74 |
+
std = (std,) * 3
|
75 |
+
|
76 |
+
if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]:
|
77 |
+
# for square size, pass size as int so that Resize() uses aspect preserving shortest edge
|
78 |
+
image_size = image_size[0]
|
79 |
+
|
80 |
+
normalize = Normalize(mean=mean, std=std)
|
81 |
+
if is_train:
|
82 |
+
return Compose([
|
83 |
+
RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC),
|
84 |
+
_convert_to_rgb,
|
85 |
+
ToTensor(),
|
86 |
+
normalize,
|
87 |
+
])
|
88 |
+
else:
|
89 |
+
if resize_longest_max:
|
90 |
+
transforms = [
|
91 |
+
ResizeMaxSize(image_size, fill=fill_color)
|
92 |
+
]
|
93 |
+
else:
|
94 |
+
transforms = [
|
95 |
+
Resize(image_size, interpolation=InterpolationMode.BICUBIC),
|
96 |
+
CenterCrop(image_size),
|
97 |
+
]
|
98 |
+
transforms.extend([
|
99 |
+
_convert_to_rgb,
|
100 |
+
ToTensor(),
|
101 |
+
normalize,
|
102 |
+
])
|
103 |
+
return Compose(transforms)
|
eva_clip/transformer.py
ADDED
@@ -0,0 +1,737 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import logging
|
3 |
+
from collections import OrderedDict
|
4 |
+
import math
|
5 |
+
from typing import Callable, Optional, Sequence
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
from torch import nn
|
9 |
+
from torch.nn import functional as F
|
10 |
+
|
11 |
+
try:
|
12 |
+
from timm.models.layers import trunc_normal_
|
13 |
+
except:
|
14 |
+
from timm.layers import trunc_normal_
|
15 |
+
|
16 |
+
from .rope import VisionRotaryEmbedding, VisionRotaryEmbeddingFast
|
17 |
+
from .utils import to_2tuple
|
18 |
+
|
19 |
+
if os.getenv('ENV_TYPE') == 'deepspeed':
|
20 |
+
try:
|
21 |
+
import deepspeed
|
22 |
+
from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
|
23 |
+
except:
|
24 |
+
print("Please 'pip install deepspeed'")
|
25 |
+
deepspeed = None
|
26 |
+
from torch.utils.checkpoint import checkpoint
|
27 |
+
else:
|
28 |
+
from torch.utils.checkpoint import checkpoint
|
29 |
+
|
30 |
+
try:
|
31 |
+
import xformers.ops as xops
|
32 |
+
except ImportError:
|
33 |
+
xops = None
|
34 |
+
print("Please 'pip install xformers'")
|
35 |
+
|
36 |
+
class LayerNormFp32(nn.LayerNorm):
|
37 |
+
"""Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back)."""
|
38 |
+
def __init__(self, *args, **kwargs):
|
39 |
+
super().__init__(*args, **kwargs)
|
40 |
+
|
41 |
+
def forward(self, x: torch.Tensor):
|
42 |
+
output = F.layer_norm(
|
43 |
+
x.float(),
|
44 |
+
self.normalized_shape,
|
45 |
+
self.weight.float() if self.weight is not None else None,
|
46 |
+
self.bias.float() if self.bias is not None else None,
|
47 |
+
self.eps,
|
48 |
+
)
|
49 |
+
return output.type_as(x)
|
50 |
+
|
51 |
+
|
52 |
+
class LayerNorm(nn.LayerNorm):
|
53 |
+
"""Subclass torch's LayerNorm (with cast back to input dtype)."""
|
54 |
+
|
55 |
+
def forward(self, x: torch.Tensor):
|
56 |
+
orig_type = x.dtype
|
57 |
+
x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
|
58 |
+
return x.to(orig_type)
|
59 |
+
|
60 |
+
class QuickGELU(nn.Module):
|
61 |
+
# NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
|
62 |
+
def forward(self, x: torch.Tensor):
|
63 |
+
return x * torch.sigmoid(1.702 * x)
|
64 |
+
|
65 |
+
|
66 |
+
class LayerScale(nn.Module):
|
67 |
+
def __init__(self, dim, init_values=1e-5, inplace=False):
|
68 |
+
super().__init__()
|
69 |
+
self.inplace = inplace
|
70 |
+
self.gamma = nn.Parameter(init_values * torch.ones(dim))
|
71 |
+
|
72 |
+
def forward(self, x):
|
73 |
+
return x.mul_(self.gamma) if self.inplace else x * self.gamma
|
74 |
+
|
75 |
+
class PatchDropout(nn.Module):
|
76 |
+
"""
|
77 |
+
https://arxiv.org/abs/2212.00794
|
78 |
+
"""
|
79 |
+
|
80 |
+
def __init__(self, prob, exclude_first_token=True):
|
81 |
+
super().__init__()
|
82 |
+
assert 0 <= prob < 1.
|
83 |
+
self.prob = prob
|
84 |
+
self.exclude_first_token = exclude_first_token # exclude CLS token
|
85 |
+
logging.info(f"os.getenv('RoPE')={os.getenv('RoPE')}")
|
86 |
+
|
87 |
+
def forward(self, x):
|
88 |
+
if not self.training or self.prob == 0.:
|
89 |
+
return x
|
90 |
+
|
91 |
+
if self.exclude_first_token:
|
92 |
+
cls_tokens, x = x[:, :1], x[:, 1:]
|
93 |
+
else:
|
94 |
+
cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
|
95 |
+
|
96 |
+
batch = x.size()[0]
|
97 |
+
num_tokens = x.size()[1]
|
98 |
+
|
99 |
+
batch_indices = torch.arange(batch)
|
100 |
+
batch_indices = batch_indices[..., None]
|
101 |
+
|
102 |
+
keep_prob = 1 - self.prob
|
103 |
+
num_patches_keep = max(1, int(num_tokens * keep_prob))
|
104 |
+
|
105 |
+
rand = torch.randn(batch, num_tokens)
|
106 |
+
patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
|
107 |
+
|
108 |
+
x = x[batch_indices, patch_indices_keep]
|
109 |
+
|
110 |
+
if self.exclude_first_token:
|
111 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
112 |
+
|
113 |
+
if self.training and os.getenv('RoPE') == '1':
|
114 |
+
return x, patch_indices_keep
|
115 |
+
|
116 |
+
return x
|
117 |
+
|
118 |
+
|
119 |
+
def _in_projection_packed(
|
120 |
+
q: torch.Tensor,
|
121 |
+
k: torch.Tensor,
|
122 |
+
v: torch.Tensor,
|
123 |
+
w: torch.Tensor,
|
124 |
+
b: Optional[torch.Tensor] = None,
|
125 |
+
):
|
126 |
+
"""
|
127 |
+
https://github.com/pytorch/pytorch/blob/db2a237763eb8693a20788be94f8c192e762baa8/torch/nn/functional.py#L4726
|
128 |
+
"""
|
129 |
+
E = q.size(-1)
|
130 |
+
if k is v:
|
131 |
+
if q is k:
|
132 |
+
# self-attention
|
133 |
+
return F.linear(q, w, b).chunk(3, dim=-1)
|
134 |
+
else:
|
135 |
+
# encoder-decoder attention
|
136 |
+
w_q, w_kv = w.split([E, E * 2])
|
137 |
+
if b is None:
|
138 |
+
b_q = b_kv = None
|
139 |
+
else:
|
140 |
+
b_q, b_kv = b.split([E, E * 2])
|
141 |
+
return (F.linear(q, w_q, b_q),) + F.linear(k, w_kv, b_kv).chunk(2, dim=-1)
|
142 |
+
else:
|
143 |
+
w_q, w_k, w_v = w.chunk(3)
|
144 |
+
if b is None:
|
145 |
+
b_q = b_k = b_v = None
|
146 |
+
else:
|
147 |
+
b_q, b_k, b_v = b.chunk(3)
|
148 |
+
return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v)
|
149 |
+
|
150 |
+
class Attention(nn.Module):
|
151 |
+
def __init__(
|
152 |
+
self,
|
153 |
+
dim,
|
154 |
+
num_heads=8,
|
155 |
+
qkv_bias=True,
|
156 |
+
scaled_cosine=False,
|
157 |
+
scale_heads=False,
|
158 |
+
logit_scale_max=math.log(1. / 0.01),
|
159 |
+
attn_drop=0.,
|
160 |
+
proj_drop=0.,
|
161 |
+
xattn=False,
|
162 |
+
rope=False
|
163 |
+
):
|
164 |
+
super().__init__()
|
165 |
+
self.scaled_cosine = scaled_cosine
|
166 |
+
self.scale_heads = scale_heads
|
167 |
+
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
|
168 |
+
self.num_heads = num_heads
|
169 |
+
self.head_dim = dim // num_heads
|
170 |
+
self.scale = self.head_dim ** -0.5
|
171 |
+
self.logit_scale_max = logit_scale_max
|
172 |
+
|
173 |
+
# keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original
|
174 |
+
self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
|
175 |
+
if qkv_bias:
|
176 |
+
self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
|
177 |
+
else:
|
178 |
+
self.in_proj_bias = None
|
179 |
+
|
180 |
+
if self.scaled_cosine:
|
181 |
+
self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
|
182 |
+
else:
|
183 |
+
self.logit_scale = None
|
184 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
185 |
+
if self.scale_heads:
|
186 |
+
self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
|
187 |
+
else:
|
188 |
+
self.head_scale = None
|
189 |
+
self.out_proj = nn.Linear(dim, dim)
|
190 |
+
self.out_drop = nn.Dropout(proj_drop)
|
191 |
+
self.xattn = xattn
|
192 |
+
self.xattn_drop = attn_drop
|
193 |
+
self.rope = rope
|
194 |
+
|
195 |
+
def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
|
196 |
+
L, N, C = x.shape
|
197 |
+
q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)
|
198 |
+
if self.xattn:
|
199 |
+
q = q.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1)
|
200 |
+
k = k.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1)
|
201 |
+
v = v.contiguous().view(L, N, self.num_heads, -1).transpose(0, 1)
|
202 |
+
|
203 |
+
x = xops.memory_efficient_attention(
|
204 |
+
q, k, v,
|
205 |
+
p=self.xattn_drop,
|
206 |
+
scale=self.scale if self.logit_scale is None else None,
|
207 |
+
attn_bias=xops.LowerTriangularMask() if attn_mask is not None else None,
|
208 |
+
)
|
209 |
+
else:
|
210 |
+
q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
|
211 |
+
k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
|
212 |
+
v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
|
213 |
+
|
214 |
+
if self.logit_scale is not None:
|
215 |
+
attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))
|
216 |
+
logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
|
217 |
+
attn = attn.view(N, self.num_heads, L, L) * logit_scale
|
218 |
+
attn = attn.view(-1, L, L)
|
219 |
+
else:
|
220 |
+
q = q * self.scale
|
221 |
+
attn = torch.bmm(q, k.transpose(-1, -2))
|
222 |
+
|
223 |
+
if attn_mask is not None:
|
224 |
+
if attn_mask.dtype == torch.bool:
|
225 |
+
new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
|
226 |
+
new_attn_mask.masked_fill_(attn_mask, float("-inf"))
|
227 |
+
attn_mask = new_attn_mask
|
228 |
+
attn += attn_mask
|
229 |
+
|
230 |
+
attn = attn.softmax(dim=-1)
|
231 |
+
attn = self.attn_drop(attn)
|
232 |
+
|
233 |
+
x = torch.bmm(attn, v)
|
234 |
+
|
235 |
+
if self.head_scale is not None:
|
236 |
+
x = x.view(N, self.num_heads, L, C) * self.head_scale
|
237 |
+
x = x.view(-1, L, C)
|
238 |
+
x = x.transpose(0, 1).reshape(L, N, C)
|
239 |
+
x = self.out_proj(x)
|
240 |
+
x = self.out_drop(x)
|
241 |
+
return x
|
242 |
+
|
243 |
+
class CustomAttention(nn.Module):
|
244 |
+
def __init__(
|
245 |
+
self,
|
246 |
+
dim,
|
247 |
+
num_heads=8,
|
248 |
+
qkv_bias=True,
|
249 |
+
scaled_cosine=True,
|
250 |
+
scale_heads=False,
|
251 |
+
logit_scale_max=math.log(1. / 0.01),
|
252 |
+
attn_drop=0.,
|
253 |
+
proj_drop=0.,
|
254 |
+
xattn=False
|
255 |
+
):
|
256 |
+
super().__init__()
|
257 |
+
self.scaled_cosine = scaled_cosine
|
258 |
+
self.scale_heads = scale_heads
|
259 |
+
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
|
260 |
+
self.num_heads = num_heads
|
261 |
+
self.head_dim = dim // num_heads
|
262 |
+
self.scale = self.head_dim ** -0.5
|
263 |
+
self.logit_scale_max = logit_scale_max
|
264 |
+
|
265 |
+
# keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original
|
266 |
+
self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
|
267 |
+
if qkv_bias:
|
268 |
+
self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
|
269 |
+
else:
|
270 |
+
self.in_proj_bias = None
|
271 |
+
|
272 |
+
if self.scaled_cosine:
|
273 |
+
self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
|
274 |
+
else:
|
275 |
+
self.logit_scale = None
|
276 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
277 |
+
if self.scale_heads:
|
278 |
+
self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
|
279 |
+
else:
|
280 |
+
self.head_scale = None
|
281 |
+
self.out_proj = nn.Linear(dim, dim)
|
282 |
+
self.out_drop = nn.Dropout(proj_drop)
|
283 |
+
self.xattn = xattn
|
284 |
+
self.xattn_drop = attn_drop
|
285 |
+
|
286 |
+
def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
|
287 |
+
q, k, v = _in_projection_packed(query, key, value, self.in_proj_weight, self.in_proj_bias)
|
288 |
+
N_q, B_q, C_q = q.shape
|
289 |
+
N_k, B_k, C_k = k.shape
|
290 |
+
N_v, B_v, C_v = v.shape
|
291 |
+
if self.xattn:
|
292 |
+
# B, N, C -> B, N, num_heads, C
|
293 |
+
q = q.permute(1, 0, 2).reshape(B_q, N_q, self.num_heads, -1)
|
294 |
+
k = k.permute(1, 0, 2).reshape(B_k, N_k, self.num_heads, -1)
|
295 |
+
v = v.permute(1, 0, 2).reshape(B_v, N_v, self.num_heads, -1)
|
296 |
+
|
297 |
+
x = xops.memory_efficient_attention(
|
298 |
+
q, k, v,
|
299 |
+
p=self.xattn_drop,
|
300 |
+
scale=self.scale if self.logit_scale is None else None,
|
301 |
+
attn_bias=xops.LowerTriangularMask() if attn_mask is not None else None
|
302 |
+
)
|
303 |
+
else:
|
304 |
+
# B*H, L, C
|
305 |
+
q = q.contiguous().view(N_q, B_q * self.num_heads, -1).transpose(0, 1)
|
306 |
+
k = k.contiguous().view(N_k, B_k * self.num_heads, -1).transpose(0, 1)
|
307 |
+
v = v.contiguous().view(N_v, B_v * self.num_heads, -1).transpose(0, 1)
|
308 |
+
|
309 |
+
if self.logit_scale is not None:
|
310 |
+
# B*H, N_q, N_k
|
311 |
+
attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))
|
312 |
+
logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
|
313 |
+
attn = attn.view(B_q, self.num_heads, N_q, N_k) * logit_scale
|
314 |
+
attn = attn.view(-1, N_q, N_k)
|
315 |
+
else:
|
316 |
+
q = q * self.scale
|
317 |
+
attn = torch.bmm(q, k.transpose(-1, -2))
|
318 |
+
|
319 |
+
if attn_mask is not None:
|
320 |
+
if attn_mask.dtype == torch.bool:
|
321 |
+
new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
|
322 |
+
new_attn_mask.masked_fill_(attn_mask, float("-inf"))
|
323 |
+
attn_mask = new_attn_mask
|
324 |
+
attn += attn_mask
|
325 |
+
|
326 |
+
attn = attn.softmax(dim=-1)
|
327 |
+
attn = self.attn_drop(attn)
|
328 |
+
|
329 |
+
x = torch.bmm(attn, v)
|
330 |
+
|
331 |
+
if self.head_scale is not None:
|
332 |
+
x = x.view(B_q, self.num_heads, N_q, C_q) * self.head_scale
|
333 |
+
x = x.view(-1, N_q, C_q)
|
334 |
+
x = x.transpose(0, 1).reshape(N_q, B_q, C_q)
|
335 |
+
x = self.out_proj(x)
|
336 |
+
x = self.out_drop(x)
|
337 |
+
return x
|
338 |
+
|
339 |
+
class CustomResidualAttentionBlock(nn.Module):
|
340 |
+
def __init__(
|
341 |
+
self,
|
342 |
+
d_model: int,
|
343 |
+
n_head: int,
|
344 |
+
mlp_ratio: float = 4.0,
|
345 |
+
ls_init_value: float = None,
|
346 |
+
act_layer: Callable = nn.GELU,
|
347 |
+
norm_layer: Callable = LayerNorm,
|
348 |
+
scale_cosine_attn: bool = False,
|
349 |
+
scale_heads: bool = False,
|
350 |
+
scale_attn: bool = False,
|
351 |
+
scale_fc: bool = False,
|
352 |
+
cross_attn: bool = False,
|
353 |
+
xattn: bool = False,
|
354 |
+
):
|
355 |
+
super().__init__()
|
356 |
+
|
357 |
+
self.ln_1 = norm_layer(d_model)
|
358 |
+
self.ln_1_k = norm_layer(d_model) if cross_attn else self.ln_1
|
359 |
+
self.ln_1_v = norm_layer(d_model) if cross_attn else self.ln_1
|
360 |
+
self.attn = CustomAttention(
|
361 |
+
d_model, n_head,
|
362 |
+
qkv_bias=True,
|
363 |
+
attn_drop=0.,
|
364 |
+
proj_drop=0.,
|
365 |
+
scaled_cosine=scale_cosine_attn,
|
366 |
+
scale_heads=scale_heads,
|
367 |
+
xattn=xattn
|
368 |
+
)
|
369 |
+
|
370 |
+
self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity()
|
371 |
+
self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
|
372 |
+
|
373 |
+
self.ln_2 = norm_layer(d_model)
|
374 |
+
mlp_width = int(d_model * mlp_ratio)
|
375 |
+
self.mlp = nn.Sequential(OrderedDict([
|
376 |
+
("c_fc", nn.Linear(d_model, mlp_width)),
|
377 |
+
('ln', norm_layer(mlp_width) if scale_fc else nn.Identity()),
|
378 |
+
("gelu", act_layer()),
|
379 |
+
("c_proj", nn.Linear(mlp_width, d_model))
|
380 |
+
]))
|
381 |
+
|
382 |
+
self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
|
383 |
+
|
384 |
+
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
|
385 |
+
q = q + self.ls_1(self.ln_attn(self.attn(self.ln_1(q), self.ln_1_k(k), self.ln_1_v(v), attn_mask=attn_mask)))
|
386 |
+
q = q + self.ls_2(self.mlp(self.ln_2(q)))
|
387 |
+
return q
|
388 |
+
|
389 |
+
class CustomTransformer(nn.Module):
|
390 |
+
def __init__(
|
391 |
+
self,
|
392 |
+
width: int,
|
393 |
+
layers: int,
|
394 |
+
heads: int,
|
395 |
+
mlp_ratio: float = 4.0,
|
396 |
+
ls_init_value: float = None,
|
397 |
+
act_layer: Callable = nn.GELU,
|
398 |
+
norm_layer: Callable = LayerNorm,
|
399 |
+
scale_cosine_attn: bool = True,
|
400 |
+
scale_heads: bool = False,
|
401 |
+
scale_attn: bool = False,
|
402 |
+
scale_fc: bool = False,
|
403 |
+
cross_attn: bool = False,
|
404 |
+
xattn: bool = False,
|
405 |
+
):
|
406 |
+
super().__init__()
|
407 |
+
self.width = width
|
408 |
+
self.layers = layers
|
409 |
+
self.grad_checkpointing = False
|
410 |
+
self.xattn = xattn
|
411 |
+
|
412 |
+
self.resblocks = nn.ModuleList([
|
413 |
+
CustomResidualAttentionBlock(
|
414 |
+
width,
|
415 |
+
heads,
|
416 |
+
mlp_ratio,
|
417 |
+
ls_init_value=ls_init_value,
|
418 |
+
act_layer=act_layer,
|
419 |
+
norm_layer=norm_layer,
|
420 |
+
scale_cosine_attn=scale_cosine_attn,
|
421 |
+
scale_heads=scale_heads,
|
422 |
+
scale_attn=scale_attn,
|
423 |
+
scale_fc=scale_fc,
|
424 |
+
cross_attn=cross_attn,
|
425 |
+
xattn=xattn)
|
426 |
+
for _ in range(layers)
|
427 |
+
])
|
428 |
+
|
429 |
+
def get_cast_dtype(self) -> torch.dtype:
|
430 |
+
return self.resblocks[0].mlp.c_fc.weight.dtype
|
431 |
+
|
432 |
+
def forward(self, q: torch.Tensor, k: torch.Tensor = None, v: torch.Tensor = None, attn_mask: Optional[torch.Tensor] = None):
|
433 |
+
if k is None and v is None:
|
434 |
+
k = v = q
|
435 |
+
for r in self.resblocks:
|
436 |
+
if self.grad_checkpointing and not torch.jit.is_scripting():
|
437 |
+
q = checkpoint(r, q, k, v, attn_mask)
|
438 |
+
else:
|
439 |
+
q = r(q, k, v, attn_mask=attn_mask)
|
440 |
+
return q
|
441 |
+
|
442 |
+
|
443 |
+
class ResidualAttentionBlock(nn.Module):
|
444 |
+
def __init__(
|
445 |
+
self,
|
446 |
+
d_model: int,
|
447 |
+
n_head: int,
|
448 |
+
mlp_ratio: float = 4.0,
|
449 |
+
ls_init_value: float = None,
|
450 |
+
act_layer: Callable = nn.GELU,
|
451 |
+
norm_layer: Callable = LayerNorm,
|
452 |
+
xattn: bool = False,
|
453 |
+
):
|
454 |
+
super().__init__()
|
455 |
+
|
456 |
+
self.ln_1 = norm_layer(d_model)
|
457 |
+
if xattn:
|
458 |
+
self.attn = Attention(d_model, n_head, xattn=True)
|
459 |
+
else:
|
460 |
+
self.attn = nn.MultiheadAttention(d_model, n_head)
|
461 |
+
self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
|
462 |
+
|
463 |
+
self.ln_2 = norm_layer(d_model)
|
464 |
+
mlp_width = int(d_model * mlp_ratio)
|
465 |
+
self.mlp = nn.Sequential(OrderedDict([
|
466 |
+
("c_fc", nn.Linear(d_model, mlp_width)),
|
467 |
+
("gelu", act_layer()),
|
468 |
+
("c_proj", nn.Linear(mlp_width, d_model))
|
469 |
+
]))
|
470 |
+
|
471 |
+
self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()
|
472 |
+
self.xattn = xattn
|
473 |
+
|
474 |
+
def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
|
475 |
+
attn_mask = attn_mask.to(x.dtype) if attn_mask is not None else None
|
476 |
+
if self.xattn:
|
477 |
+
return self.attn(x, attn_mask=attn_mask)
|
478 |
+
return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0]
|
479 |
+
|
480 |
+
def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
|
481 |
+
x = x + self.ls_1(self.attention(self.ln_1(x), attn_mask=attn_mask))
|
482 |
+
x = x + self.ls_2(self.mlp(self.ln_2(x)))
|
483 |
+
return x
|
484 |
+
|
485 |
+
class Transformer(nn.Module):
|
486 |
+
def __init__(
|
487 |
+
self,
|
488 |
+
width: int,
|
489 |
+
layers: int,
|
490 |
+
heads: int,
|
491 |
+
mlp_ratio: float = 4.0,
|
492 |
+
ls_init_value: float = None,
|
493 |
+
act_layer: Callable = nn.GELU,
|
494 |
+
norm_layer: Callable = LayerNorm,
|
495 |
+
xattn: bool = False,
|
496 |
+
):
|
497 |
+
super().__init__()
|
498 |
+
self.width = width
|
499 |
+
self.layers = layers
|
500 |
+
self.grad_checkpointing = False
|
501 |
+
|
502 |
+
self.resblocks = nn.ModuleList([
|
503 |
+
ResidualAttentionBlock(
|
504 |
+
width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer, xattn=xattn)
|
505 |
+
for _ in range(layers)
|
506 |
+
])
|
507 |
+
|
508 |
+
def get_cast_dtype(self) -> torch.dtype:
|
509 |
+
return self.resblocks[0].mlp.c_fc.weight.dtype
|
510 |
+
|
511 |
+
def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
|
512 |
+
for r in self.resblocks:
|
513 |
+
if self.grad_checkpointing and not torch.jit.is_scripting():
|
514 |
+
x = checkpoint(r, x, attn_mask)
|
515 |
+
else:
|
516 |
+
x = r(x, attn_mask=attn_mask)
|
517 |
+
return x
|
518 |
+
|
519 |
+
|
520 |
+
class VisionTransformer(nn.Module):
|
521 |
+
def __init__(
|
522 |
+
self,
|
523 |
+
image_size: int,
|
524 |
+
patch_size: int,
|
525 |
+
width: int,
|
526 |
+
layers: int,
|
527 |
+
heads: int,
|
528 |
+
mlp_ratio: float,
|
529 |
+
ls_init_value: float = None,
|
530 |
+
patch_dropout: float = 0.,
|
531 |
+
global_average_pool: bool = False,
|
532 |
+
output_dim: int = 512,
|
533 |
+
act_layer: Callable = nn.GELU,
|
534 |
+
norm_layer: Callable = LayerNorm,
|
535 |
+
xattn: bool = False,
|
536 |
+
):
|
537 |
+
super().__init__()
|
538 |
+
self.image_size = to_2tuple(image_size)
|
539 |
+
self.patch_size = to_2tuple(patch_size)
|
540 |
+
self.grid_size = (self.image_size[0] // self.patch_size[0], self.image_size[1] // self.patch_size[1])
|
541 |
+
self.output_dim = output_dim
|
542 |
+
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
|
543 |
+
|
544 |
+
scale = width ** -0.5
|
545 |
+
self.class_embedding = nn.Parameter(scale * torch.randn(width))
|
546 |
+
self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width))
|
547 |
+
|
548 |
+
# setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
|
549 |
+
self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()
|
550 |
+
self.ln_pre = norm_layer(width)
|
551 |
+
|
552 |
+
self.transformer = Transformer(
|
553 |
+
width,
|
554 |
+
layers,
|
555 |
+
heads,
|
556 |
+
mlp_ratio,
|
557 |
+
ls_init_value=ls_init_value,
|
558 |
+
act_layer=act_layer,
|
559 |
+
norm_layer=norm_layer,
|
560 |
+
xattn=xattn
|
561 |
+
)
|
562 |
+
|
563 |
+
self.global_average_pool = global_average_pool
|
564 |
+
self.ln_post = norm_layer(width)
|
565 |
+
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
|
566 |
+
|
567 |
+
def lock(self, unlocked_groups=0, freeze_bn_stats=False):
|
568 |
+
for param in self.parameters():
|
569 |
+
param.requires_grad = False
|
570 |
+
|
571 |
+
if unlocked_groups != 0:
|
572 |
+
groups = [
|
573 |
+
[
|
574 |
+
self.conv1,
|
575 |
+
self.class_embedding,
|
576 |
+
self.positional_embedding,
|
577 |
+
self.ln_pre,
|
578 |
+
],
|
579 |
+
*self.transformer.resblocks[:-1],
|
580 |
+
[
|
581 |
+
self.transformer.resblocks[-1],
|
582 |
+
self.ln_post,
|
583 |
+
],
|
584 |
+
self.proj,
|
585 |
+
]
|
586 |
+
|
587 |
+
def _unlock(x):
|
588 |
+
if isinstance(x, Sequence):
|
589 |
+
for g in x:
|
590 |
+
_unlock(g)
|
591 |
+
else:
|
592 |
+
if isinstance(x, torch.nn.Parameter):
|
593 |
+
x.requires_grad = True
|
594 |
+
else:
|
595 |
+
for p in x.parameters():
|
596 |
+
p.requires_grad = True
|
597 |
+
|
598 |
+
_unlock(groups[-unlocked_groups:])
|
599 |
+
|
600 |
+
def get_num_layers(self):
|
601 |
+
return self.transformer.layers
|
602 |
+
|
603 |
+
@torch.jit.ignore
|
604 |
+
def set_grad_checkpointing(self, enable=True):
|
605 |
+
self.transformer.grad_checkpointing = enable
|
606 |
+
|
607 |
+
@torch.jit.ignore
|
608 |
+
def no_weight_decay(self):
|
609 |
+
return {'positional_embedding', 'class_embedding'}
|
610 |
+
|
611 |
+
def forward(self, x: torch.Tensor, return_all_features: bool=False):
|
612 |
+
x = self.conv1(x) # shape = [*, width, grid, grid]
|
613 |
+
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
|
614 |
+
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
|
615 |
+
x = torch.cat(
|
616 |
+
[self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
|
617 |
+
x], dim=1) # shape = [*, grid ** 2 + 1, width]
|
618 |
+
x = x + self.positional_embedding.to(x.dtype)
|
619 |
+
|
620 |
+
# a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
|
621 |
+
x = self.patch_dropout(x)
|
622 |
+
x = self.ln_pre(x)
|
623 |
+
|
624 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
625 |
+
x = self.transformer(x)
|
626 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
627 |
+
|
628 |
+
if not return_all_features:
|
629 |
+
if self.global_average_pool:
|
630 |
+
x = x.mean(dim=1) #x = x[:,1:,:].mean(dim=1)
|
631 |
+
else:
|
632 |
+
x = x[:, 0]
|
633 |
+
|
634 |
+
x = self.ln_post(x)
|
635 |
+
|
636 |
+
if self.proj is not None:
|
637 |
+
x = x @ self.proj
|
638 |
+
|
639 |
+
return x
|
640 |
+
|
641 |
+
|
642 |
+
class TextTransformer(nn.Module):
|
643 |
+
def __init__(
|
644 |
+
self,
|
645 |
+
context_length: int = 77,
|
646 |
+
vocab_size: int = 49408,
|
647 |
+
width: int = 512,
|
648 |
+
heads: int = 8,
|
649 |
+
layers: int = 12,
|
650 |
+
ls_init_value: float = None,
|
651 |
+
output_dim: int = 512,
|
652 |
+
act_layer: Callable = nn.GELU,
|
653 |
+
norm_layer: Callable = LayerNorm,
|
654 |
+
xattn: bool= False,
|
655 |
+
attn_mask: bool = True
|
656 |
+
):
|
657 |
+
super().__init__()
|
658 |
+
self.context_length = context_length
|
659 |
+
self.vocab_size = vocab_size
|
660 |
+
self.width = width
|
661 |
+
self.output_dim = output_dim
|
662 |
+
|
663 |
+
self.token_embedding = nn.Embedding(vocab_size, width)
|
664 |
+
self.positional_embedding = nn.Parameter(torch.empty(self.context_length, width))
|
665 |
+
self.transformer = Transformer(
|
666 |
+
width=width,
|
667 |
+
layers=layers,
|
668 |
+
heads=heads,
|
669 |
+
ls_init_value=ls_init_value,
|
670 |
+
act_layer=act_layer,
|
671 |
+
norm_layer=norm_layer,
|
672 |
+
xattn=xattn
|
673 |
+
)
|
674 |
+
|
675 |
+
self.xattn = xattn
|
676 |
+
self.ln_final = norm_layer(width)
|
677 |
+
self.text_projection = nn.Parameter(torch.empty(width, output_dim))
|
678 |
+
|
679 |
+
if attn_mask:
|
680 |
+
self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False)
|
681 |
+
else:
|
682 |
+
self.attn_mask = None
|
683 |
+
|
684 |
+
self.init_parameters()
|
685 |
+
|
686 |
+
def init_parameters(self):
|
687 |
+
nn.init.normal_(self.token_embedding.weight, std=0.02)
|
688 |
+
nn.init.normal_(self.positional_embedding, std=0.01)
|
689 |
+
|
690 |
+
proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
|
691 |
+
attn_std = self.transformer.width ** -0.5
|
692 |
+
fc_std = (2 * self.transformer.width) ** -0.5
|
693 |
+
for block in self.transformer.resblocks:
|
694 |
+
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
|
695 |
+
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
|
696 |
+
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
|
697 |
+
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
|
698 |
+
|
699 |
+
if self.text_projection is not None:
|
700 |
+
nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
|
701 |
+
|
702 |
+
@torch.jit.ignore
|
703 |
+
def set_grad_checkpointing(self, enable=True):
|
704 |
+
self.transformer.grad_checkpointing = enable
|
705 |
+
|
706 |
+
@torch.jit.ignore
|
707 |
+
def no_weight_decay(self):
|
708 |
+
# return {'positional_embedding', 'token_embedding'}
|
709 |
+
return {'positional_embedding'}
|
710 |
+
|
711 |
+
def get_num_layers(self):
|
712 |
+
return self.transformer.layers
|
713 |
+
|
714 |
+
def build_attention_mask(self):
|
715 |
+
# lazily create causal attention mask, with full attention between the vision tokens
|
716 |
+
# pytorch uses additive attention mask; fill with -inf
|
717 |
+
mask = torch.empty(self.context_length, self.context_length)
|
718 |
+
mask.fill_(float("-inf"))
|
719 |
+
mask.triu_(1) # zero out the lower diagonal
|
720 |
+
return mask
|
721 |
+
|
722 |
+
def forward(self, text, return_all_features: bool=False):
|
723 |
+
cast_dtype = self.transformer.get_cast_dtype()
|
724 |
+
x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
|
725 |
+
|
726 |
+
x = x + self.positional_embedding.to(cast_dtype)
|
727 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
728 |
+
x = self.transformer(x, attn_mask=self.attn_mask)
|
729 |
+
# x = self.transformer(x) # no attention mask is applied
|
730 |
+
x = x.permute(1, 0, 2) # LND -> NLD
|
731 |
+
x = self.ln_final(x)
|
732 |
+
|
733 |
+
if not return_all_features:
|
734 |
+
# x.shape = [batch_size, n_ctx, transformer.width]
|
735 |
+
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
736 |
+
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
|
737 |
+
return x
|
eva_clip/utils.py
ADDED
@@ -0,0 +1,326 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from itertools import repeat
|
2 |
+
import collections.abc
|
3 |
+
import logging
|
4 |
+
import math
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from torch import nn as nn
|
9 |
+
from torchvision.ops.misc import FrozenBatchNorm2d
|
10 |
+
import torch.nn.functional as F
|
11 |
+
|
12 |
+
# open CLIP
|
13 |
+
def resize_clip_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1):
|
14 |
+
# Rescale the grid of position embeddings when loading from state_dict
|
15 |
+
old_pos_embed = state_dict.get('visual.positional_embedding', None)
|
16 |
+
if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
|
17 |
+
return
|
18 |
+
grid_size = to_2tuple(model.visual.grid_size)
|
19 |
+
extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
|
20 |
+
new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
|
21 |
+
if new_seq_len == old_pos_embed.shape[0]:
|
22 |
+
return
|
23 |
+
|
24 |
+
if extra_tokens:
|
25 |
+
pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
|
26 |
+
else:
|
27 |
+
pos_emb_tok, pos_emb_img = None, old_pos_embed
|
28 |
+
old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
|
29 |
+
|
30 |
+
logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
|
31 |
+
pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
|
32 |
+
pos_emb_img = F.interpolate(
|
33 |
+
pos_emb_img,
|
34 |
+
size=grid_size,
|
35 |
+
mode=interpolation,
|
36 |
+
align_corners=True,
|
37 |
+
)
|
38 |
+
pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
|
39 |
+
if pos_emb_tok is not None:
|
40 |
+
new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
|
41 |
+
else:
|
42 |
+
new_pos_embed = pos_emb_img
|
43 |
+
state_dict['visual.positional_embedding'] = new_pos_embed
|
44 |
+
|
45 |
+
|
46 |
+
def resize_visual_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1):
|
47 |
+
# Rescale the grid of position embeddings when loading from state_dict
|
48 |
+
old_pos_embed = state_dict.get('positional_embedding', None)
|
49 |
+
if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
|
50 |
+
return
|
51 |
+
grid_size = to_2tuple(model.visual.grid_size)
|
52 |
+
extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
|
53 |
+
new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
|
54 |
+
if new_seq_len == old_pos_embed.shape[0]:
|
55 |
+
return
|
56 |
+
|
57 |
+
if extra_tokens:
|
58 |
+
pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
|
59 |
+
else:
|
60 |
+
pos_emb_tok, pos_emb_img = None, old_pos_embed
|
61 |
+
old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))
|
62 |
+
|
63 |
+
logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
|
64 |
+
pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
|
65 |
+
pos_emb_img = F.interpolate(
|
66 |
+
pos_emb_img,
|
67 |
+
size=grid_size,
|
68 |
+
mode=interpolation,
|
69 |
+
align_corners=True,
|
70 |
+
)
|
71 |
+
pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
|
72 |
+
if pos_emb_tok is not None:
|
73 |
+
new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
|
74 |
+
else:
|
75 |
+
new_pos_embed = pos_emb_img
|
76 |
+
state_dict['positional_embedding'] = new_pos_embed
|
77 |
+
|
78 |
+
def resize_evaclip_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1):
|
79 |
+
all_keys = list(state_dict.keys())
|
80 |
+
# interpolate position embedding
|
81 |
+
if 'visual.pos_embed' in state_dict:
|
82 |
+
pos_embed_checkpoint = state_dict['visual.pos_embed']
|
83 |
+
embedding_size = pos_embed_checkpoint.shape[-1]
|
84 |
+
num_patches = model.visual.patch_embed.num_patches
|
85 |
+
num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches
|
86 |
+
# height (== width) for the checkpoint position embedding
|
87 |
+
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
|
88 |
+
# height (== width) for the new position embedding
|
89 |
+
new_size = int(num_patches ** 0.5)
|
90 |
+
# class_token and dist_token are kept unchanged
|
91 |
+
if orig_size != new_size:
|
92 |
+
print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
|
93 |
+
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
|
94 |
+
# only the position tokens are interpolated
|
95 |
+
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
|
96 |
+
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
|
97 |
+
pos_tokens = torch.nn.functional.interpolate(
|
98 |
+
pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
|
99 |
+
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
|
100 |
+
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
|
101 |
+
state_dict['visual.pos_embed'] = new_pos_embed
|
102 |
+
|
103 |
+
patch_embed_proj = state_dict['visual.patch_embed.proj.weight']
|
104 |
+
patch_size = model.visual.patch_embed.patch_size
|
105 |
+
state_dict['visual.patch_embed.proj.weight'] = torch.nn.functional.interpolate(
|
106 |
+
patch_embed_proj.float(), size=patch_size, mode='bicubic', align_corners=False)
|
107 |
+
|
108 |
+
|
109 |
+
def resize_eva_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1):
|
110 |
+
all_keys = list(state_dict.keys())
|
111 |
+
# interpolate position embedding
|
112 |
+
if 'pos_embed' in state_dict:
|
113 |
+
pos_embed_checkpoint = state_dict['pos_embed']
|
114 |
+
embedding_size = pos_embed_checkpoint.shape[-1]
|
115 |
+
num_patches = model.visual.patch_embed.num_patches
|
116 |
+
num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches
|
117 |
+
# height (== width) for the checkpoint position embedding
|
118 |
+
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
|
119 |
+
# height (== width) for the new position embedding
|
120 |
+
new_size = int(num_patches ** 0.5)
|
121 |
+
# class_token and dist_token are kept unchanged
|
122 |
+
if orig_size != new_size:
|
123 |
+
print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
|
124 |
+
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
|
125 |
+
# only the position tokens are interpolated
|
126 |
+
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
|
127 |
+
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
|
128 |
+
pos_tokens = torch.nn.functional.interpolate(
|
129 |
+
pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
|
130 |
+
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
|
131 |
+
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
|
132 |
+
state_dict['pos_embed'] = new_pos_embed
|
133 |
+
|
134 |
+
patch_embed_proj = state_dict['patch_embed.proj.weight']
|
135 |
+
patch_size = model.visual.patch_embed.patch_size
|
136 |
+
state_dict['patch_embed.proj.weight'] = torch.nn.functional.interpolate(
|
137 |
+
patch_embed_proj.float(), size=patch_size, mode='bicubic', align_corners=False)
|
138 |
+
|
139 |
+
|
140 |
+
def resize_rel_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1):
|
141 |
+
all_keys = list(state_dict.keys())
|
142 |
+
for key in all_keys:
|
143 |
+
if "relative_position_index" in key:
|
144 |
+
state_dict.pop(key)
|
145 |
+
|
146 |
+
if "relative_position_bias_table" in key:
|
147 |
+
rel_pos_bias = state_dict[key]
|
148 |
+
src_num_pos, num_attn_heads = rel_pos_bias.size()
|
149 |
+
dst_num_pos, _ = model.visual.state_dict()[key].size()
|
150 |
+
dst_patch_shape = model.visual.patch_embed.patch_shape
|
151 |
+
if dst_patch_shape[0] != dst_patch_shape[1]:
|
152 |
+
raise NotImplementedError()
|
153 |
+
num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1)
|
154 |
+
src_size = int((src_num_pos - num_extra_tokens) ** 0.5)
|
155 |
+
dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5)
|
156 |
+
if src_size != dst_size:
|
157 |
+
print("Position interpolate for %s from %dx%d to %dx%d" % (
|
158 |
+
key, src_size, src_size, dst_size, dst_size))
|
159 |
+
extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
|
160 |
+
rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]
|
161 |
+
|
162 |
+
def geometric_progression(a, r, n):
|
163 |
+
return a * (1.0 - r ** n) / (1.0 - r)
|
164 |
+
|
165 |
+
left, right = 1.01, 1.5
|
166 |
+
while right - left > 1e-6:
|
167 |
+
q = (left + right) / 2.0
|
168 |
+
gp = geometric_progression(1, q, src_size // 2)
|
169 |
+
if gp > dst_size // 2:
|
170 |
+
right = q
|
171 |
+
else:
|
172 |
+
left = q
|
173 |
+
|
174 |
+
# if q > 1.090307:
|
175 |
+
# q = 1.090307
|
176 |
+
|
177 |
+
dis = []
|
178 |
+
cur = 1
|
179 |
+
for i in range(src_size // 2):
|
180 |
+
dis.append(cur)
|
181 |
+
cur += q ** (i + 1)
|
182 |
+
|
183 |
+
r_ids = [-_ for _ in reversed(dis)]
|
184 |
+
|
185 |
+
x = r_ids + [0] + dis
|
186 |
+
y = r_ids + [0] + dis
|
187 |
+
|
188 |
+
t = dst_size // 2.0
|
189 |
+
dx = np.arange(-t, t + 0.1, 1.0)
|
190 |
+
dy = np.arange(-t, t + 0.1, 1.0)
|
191 |
+
|
192 |
+
print("Original positions = %s" % str(x))
|
193 |
+
print("Target positions = %s" % str(dx))
|
194 |
+
|
195 |
+
all_rel_pos_bias = []
|
196 |
+
|
197 |
+
for i in range(num_attn_heads):
|
198 |
+
z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy()
|
199 |
+
f = F.interpolate.interp2d(x, y, z, kind='cubic')
|
200 |
+
all_rel_pos_bias.append(
|
201 |
+
torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(rel_pos_bias.device))
|
202 |
+
|
203 |
+
rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
|
204 |
+
|
205 |
+
new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0)
|
206 |
+
state_dict[key] = new_rel_pos_bias
|
207 |
+
|
208 |
+
# interpolate position embedding
|
209 |
+
if 'pos_embed' in state_dict:
|
210 |
+
pos_embed_checkpoint = state_dict['pos_embed']
|
211 |
+
embedding_size = pos_embed_checkpoint.shape[-1]
|
212 |
+
num_patches = model.visual.patch_embed.num_patches
|
213 |
+
num_extra_tokens = model.visual.pos_embed.shape[-2] - num_patches
|
214 |
+
# height (== width) for the checkpoint position embedding
|
215 |
+
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
|
216 |
+
# height (== width) for the new position embedding
|
217 |
+
new_size = int(num_patches ** 0.5)
|
218 |
+
# class_token and dist_token are kept unchanged
|
219 |
+
if orig_size != new_size:
|
220 |
+
print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
|
221 |
+
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
|
222 |
+
# only the position tokens are interpolated
|
223 |
+
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
|
224 |
+
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
|
225 |
+
pos_tokens = torch.nn.functional.interpolate(
|
226 |
+
pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
|
227 |
+
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
|
228 |
+
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
|
229 |
+
state_dict['pos_embed'] = new_pos_embed
|
230 |
+
|
231 |
+
patch_embed_proj = state_dict['patch_embed.proj.weight']
|
232 |
+
patch_size = model.visual.patch_embed.patch_size
|
233 |
+
state_dict['patch_embed.proj.weight'] = torch.nn.functional.interpolate(
|
234 |
+
patch_embed_proj.float(), size=patch_size, mode='bicubic', align_corners=False)
|
235 |
+
|
236 |
+
|
237 |
+
def freeze_batch_norm_2d(module, module_match={}, name=''):
|
238 |
+
"""
|
239 |
+
Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
|
240 |
+
itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
|
241 |
+
returned. Otherwise, the module is walked recursively and submodules are converted in place.
|
242 |
+
|
243 |
+
Args:
|
244 |
+
module (torch.nn.Module): Any PyTorch module.
|
245 |
+
module_match (dict): Dictionary of full module names to freeze (all if empty)
|
246 |
+
name (str): Full module name (prefix)
|
247 |
+
|
248 |
+
Returns:
|
249 |
+
torch.nn.Module: Resulting module
|
250 |
+
|
251 |
+
Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
|
252 |
+
"""
|
253 |
+
res = module
|
254 |
+
is_match = True
|
255 |
+
if module_match:
|
256 |
+
is_match = name in module_match
|
257 |
+
if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)):
|
258 |
+
res = FrozenBatchNorm2d(module.num_features)
|
259 |
+
res.num_features = module.num_features
|
260 |
+
res.affine = module.affine
|
261 |
+
if module.affine:
|
262 |
+
res.weight.data = module.weight.data.clone().detach()
|
263 |
+
res.bias.data = module.bias.data.clone().detach()
|
264 |
+
res.running_mean.data = module.running_mean.data
|
265 |
+
res.running_var.data = module.running_var.data
|
266 |
+
res.eps = module.eps
|
267 |
+
else:
|
268 |
+
for child_name, child in module.named_children():
|
269 |
+
full_child_name = '.'.join([name, child_name]) if name else child_name
|
270 |
+
new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
|
271 |
+
if new_child is not child:
|
272 |
+
res.add_module(child_name, new_child)
|
273 |
+
return res
|
274 |
+
|
275 |
+
|
276 |
+
# From PyTorch internals
|
277 |
+
def _ntuple(n):
|
278 |
+
def parse(x):
|
279 |
+
if isinstance(x, collections.abc.Iterable):
|
280 |
+
return x
|
281 |
+
return tuple(repeat(x, n))
|
282 |
+
return parse
|
283 |
+
|
284 |
+
|
285 |
+
to_1tuple = _ntuple(1)
|
286 |
+
to_2tuple = _ntuple(2)
|
287 |
+
to_3tuple = _ntuple(3)
|
288 |
+
to_4tuple = _ntuple(4)
|
289 |
+
to_ntuple = lambda n, x: _ntuple(n)(x)
|
290 |
+
|
291 |
+
|
292 |
+
def is_logging(args):
|
293 |
+
def is_global_master(args):
|
294 |
+
return args.rank == 0
|
295 |
+
|
296 |
+
def is_local_master(args):
|
297 |
+
return args.local_rank == 0
|
298 |
+
|
299 |
+
def is_master(args, local=False):
|
300 |
+
return is_local_master(args) if local else is_global_master(args)
|
301 |
+
return is_master
|
302 |
+
|
303 |
+
|
304 |
+
class AllGather(torch.autograd.Function):
|
305 |
+
"""An autograd function that performs allgather on a tensor.
|
306 |
+
Performs all_gather operation on the provided tensors.
|
307 |
+
*** Warning ***: torch.distributed.all_gather has no gradient.
|
308 |
+
"""
|
309 |
+
|
310 |
+
@staticmethod
|
311 |
+
def forward(ctx, tensor, rank, world_size):
|
312 |
+
tensors_gather = [torch.empty_like(tensor) for _ in range(world_size)]
|
313 |
+
torch.distributed.all_gather(tensors_gather, tensor)
|
314 |
+
ctx.rank = rank
|
315 |
+
ctx.batch_size = tensor.shape[0]
|
316 |
+
return torch.cat(tensors_gather, 0)
|
317 |
+
|
318 |
+
@staticmethod
|
319 |
+
def backward(ctx, grad_output):
|
320 |
+
return (
|
321 |
+
grad_output[ctx.batch_size * ctx.rank: ctx.batch_size * (ctx.rank + 1)],
|
322 |
+
None,
|
323 |
+
None
|
324 |
+
)
|
325 |
+
|
326 |
+
allgather = AllGather.apply
|
example_inputs/hinton.jpeg
ADDED
example_inputs/lecun.jpg
ADDED
example_inputs/lifeifei.jpg
ADDED
example_inputs/liuyifei.png
ADDED
example_inputs/pengwei.jpg
ADDED
Git LFS Details
|
example_inputs/rihanna.webp
ADDED
example_inputs/zcy.webp
ADDED
flux/__init__.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
try:
|
2 |
+
from ._version import version as __version__ # type: ignore
|
3 |
+
from ._version import version_tuple
|
4 |
+
except ImportError:
|
5 |
+
__version__ = "unknown (no version information available)"
|
6 |
+
version_tuple = (0, 0, "unknown", "noinfo")
|
7 |
+
|
8 |
+
from pathlib import Path
|
9 |
+
|
10 |
+
PACKAGE = __package__.replace("_", "-")
|
11 |
+
PACKAGE_ROOT = Path(__file__).parent
|
flux/math.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from einops import rearrange
|
3 |
+
from torch import Tensor
|
4 |
+
|
5 |
+
|
6 |
+
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
|
7 |
+
if pe is not None:
|
8 |
+
q, k = apply_rope(q, k, pe)
|
9 |
+
|
10 |
+
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
11 |
+
x = rearrange(x, "B H L D -> B L (H D)")
|
12 |
+
|
13 |
+
return x
|
14 |
+
|
15 |
+
|
16 |
+
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
17 |
+
assert dim % 2 == 0
|
18 |
+
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
|
19 |
+
omega = 1.0 / (theta**scale)
|
20 |
+
out = torch.einsum("...n,d->...nd", pos, omega)
|
21 |
+
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
|
22 |
+
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
|
23 |
+
return out.float()
|
24 |
+
|
25 |
+
|
26 |
+
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
|
27 |
+
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
|
28 |
+
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
|
29 |
+
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
|
30 |
+
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
31 |
+
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
flux/model.py
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import Tensor, nn
|
5 |
+
|
6 |
+
from flux.modules.layers import (
|
7 |
+
DoubleStreamBlock,
|
8 |
+
EmbedND,
|
9 |
+
LastLayer,
|
10 |
+
MLPEmbedder,
|
11 |
+
SingleStreamBlock,
|
12 |
+
timestep_embedding,
|
13 |
+
)
|
14 |
+
|
15 |
+
DEVICE = torch.device("cuda")
|
16 |
+
|
17 |
+
@dataclass
|
18 |
+
class FluxParams:
|
19 |
+
in_channels: int
|
20 |
+
vec_in_dim: int
|
21 |
+
context_in_dim: int
|
22 |
+
hidden_size: int
|
23 |
+
mlp_ratio: float
|
24 |
+
num_heads: int
|
25 |
+
depth: int
|
26 |
+
depth_single_blocks: int
|
27 |
+
axes_dim: list[int]
|
28 |
+
theta: int
|
29 |
+
qkv_bias: bool
|
30 |
+
guidance_embed: bool
|
31 |
+
|
32 |
+
|
33 |
+
class Flux(nn.Module):
|
34 |
+
"""
|
35 |
+
Transformer model for flow matching on sequences.
|
36 |
+
"""
|
37 |
+
|
38 |
+
def __init__(self, params: FluxParams):
|
39 |
+
super().__init__()
|
40 |
+
|
41 |
+
self.params = params
|
42 |
+
self.in_channels = params.in_channels
|
43 |
+
self.out_channels = self.in_channels
|
44 |
+
if params.hidden_size % params.num_heads != 0:
|
45 |
+
raise ValueError(
|
46 |
+
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
|
47 |
+
)
|
48 |
+
pe_dim = params.hidden_size // params.num_heads
|
49 |
+
if sum(params.axes_dim) != pe_dim:
|
50 |
+
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
|
51 |
+
self.hidden_size = params.hidden_size
|
52 |
+
self.num_heads = params.num_heads
|
53 |
+
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
|
54 |
+
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
|
55 |
+
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
|
56 |
+
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
|
57 |
+
self.guidance_in = (
|
58 |
+
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
|
59 |
+
)
|
60 |
+
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
|
61 |
+
|
62 |
+
self.double_blocks = nn.ModuleList(
|
63 |
+
[
|
64 |
+
DoubleStreamBlock(
|
65 |
+
self.hidden_size,
|
66 |
+
self.num_heads,
|
67 |
+
mlp_ratio=params.mlp_ratio,
|
68 |
+
qkv_bias=params.qkv_bias,
|
69 |
+
)
|
70 |
+
for _ in range(params.depth)
|
71 |
+
]
|
72 |
+
)
|
73 |
+
|
74 |
+
self.single_blocks = nn.ModuleList(
|
75 |
+
[
|
76 |
+
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
|
77 |
+
for _ in range(params.depth_single_blocks)
|
78 |
+
]
|
79 |
+
)
|
80 |
+
|
81 |
+
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
|
82 |
+
|
83 |
+
self.pulid_ca = None
|
84 |
+
self.pulid_double_interval = 2
|
85 |
+
self.pulid_single_interval = 4
|
86 |
+
|
87 |
+
def forward(
|
88 |
+
self,
|
89 |
+
img: Tensor,
|
90 |
+
img_ids: Tensor,
|
91 |
+
txt: Tensor,
|
92 |
+
txt_ids: Tensor,
|
93 |
+
timesteps: Tensor,
|
94 |
+
y: Tensor,
|
95 |
+
guidance: Tensor = None,
|
96 |
+
id: Tensor = None,
|
97 |
+
id_weight: float = 1.0,
|
98 |
+
aggressive_offload: bool = False,
|
99 |
+
) -> Tensor:
|
100 |
+
if img.ndim != 3 or txt.ndim != 3:
|
101 |
+
raise ValueError("Input img and txt tensors must have 3 dimensions.")
|
102 |
+
|
103 |
+
# running on sequences img
|
104 |
+
img = self.img_in(img)
|
105 |
+
vec = self.time_in(timestep_embedding(timesteps, 256))
|
106 |
+
if self.params.guidance_embed:
|
107 |
+
if guidance is None:
|
108 |
+
raise ValueError("Didn't get guidance strength for guidance distilled model.")
|
109 |
+
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
|
110 |
+
vec = vec + self.vector_in(y)
|
111 |
+
txt = self.txt_in(txt)
|
112 |
+
|
113 |
+
ids = torch.cat((txt_ids, img_ids), dim=1)
|
114 |
+
pe = self.pe_embedder(ids)
|
115 |
+
|
116 |
+
ca_idx = 0
|
117 |
+
if aggressive_offload:
|
118 |
+
self.double_blocks = self.double_blocks.to(DEVICE)
|
119 |
+
for i, block in enumerate(self.double_blocks):
|
120 |
+
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
|
121 |
+
|
122 |
+
if i % self.pulid_double_interval == 0 and id is not None:
|
123 |
+
img = img + id_weight * self.pulid_ca[ca_idx](id, img)
|
124 |
+
ca_idx += 1
|
125 |
+
if aggressive_offload:
|
126 |
+
self.double_blocks.cpu()
|
127 |
+
|
128 |
+
img = torch.cat((txt, img), 1)
|
129 |
+
if aggressive_offload:
|
130 |
+
self.single_blocks = self.single_blocks.to(DEVICE)
|
131 |
+
for i, block in enumerate(self.single_blocks):
|
132 |
+
x = block(img, vec=vec, pe=pe)
|
133 |
+
real_img, txt = x[:, txt.shape[1]:, ...], x[:, :txt.shape[1], ...]
|
134 |
+
|
135 |
+
if i % self.pulid_single_interval == 0 and id is not None:
|
136 |
+
real_img = real_img + id_weight * self.pulid_ca[ca_idx](id, real_img)
|
137 |
+
ca_idx += 1
|
138 |
+
|
139 |
+
img = torch.cat((txt, real_img), 1)
|
140 |
+
if aggressive_offload:
|
141 |
+
self.single_blocks.cpu()
|
142 |
+
img = img[:, txt.shape[1] :, ...]
|
143 |
+
|
144 |
+
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
145 |
+
return img
|
146 |
+
|
147 |
+
def components_to_gpu(self):
|
148 |
+
# everything but double_blocks, single_blocks
|
149 |
+
self.img_in.to(DEVICE)
|
150 |
+
self.time_in.to(DEVICE)
|
151 |
+
self.guidance_in.to(DEVICE)
|
152 |
+
self.vector_in.to(DEVICE)
|
153 |
+
self.txt_in.to(DEVICE)
|
154 |
+
self.pe_embedder.to(DEVICE)
|
155 |
+
self.final_layer.to(DEVICE)
|
156 |
+
if self.pulid_ca:
|
157 |
+
self.pulid_ca.to(DEVICE)
|
flux/modules/__init__.py
ADDED
File without changes
|
flux/modules/autoencoder.py
ADDED
@@ -0,0 +1,312 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from einops import rearrange
|
5 |
+
from torch import Tensor, nn
|
6 |
+
|
7 |
+
|
8 |
+
@dataclass
|
9 |
+
class AutoEncoderParams:
|
10 |
+
resolution: int
|
11 |
+
in_channels: int
|
12 |
+
ch: int
|
13 |
+
out_ch: int
|
14 |
+
ch_mult: list[int]
|
15 |
+
num_res_blocks: int
|
16 |
+
z_channels: int
|
17 |
+
scale_factor: float
|
18 |
+
shift_factor: float
|
19 |
+
|
20 |
+
|
21 |
+
def swish(x: Tensor) -> Tensor:
|
22 |
+
return x * torch.sigmoid(x)
|
23 |
+
|
24 |
+
|
25 |
+
class AttnBlock(nn.Module):
|
26 |
+
def __init__(self, in_channels: int):
|
27 |
+
super().__init__()
|
28 |
+
self.in_channels = in_channels
|
29 |
+
|
30 |
+
self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
31 |
+
|
32 |
+
self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
33 |
+
self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
34 |
+
self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
35 |
+
self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
|
36 |
+
|
37 |
+
def attention(self, h_: Tensor) -> Tensor:
|
38 |
+
h_ = self.norm(h_)
|
39 |
+
q = self.q(h_)
|
40 |
+
k = self.k(h_)
|
41 |
+
v = self.v(h_)
|
42 |
+
|
43 |
+
b, c, h, w = q.shape
|
44 |
+
q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
|
45 |
+
k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
|
46 |
+
v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
|
47 |
+
h_ = nn.functional.scaled_dot_product_attention(q, k, v)
|
48 |
+
|
49 |
+
return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
|
50 |
+
|
51 |
+
def forward(self, x: Tensor) -> Tensor:
|
52 |
+
return x + self.proj_out(self.attention(x))
|
53 |
+
|
54 |
+
|
55 |
+
class ResnetBlock(nn.Module):
|
56 |
+
def __init__(self, in_channels: int, out_channels: int):
|
57 |
+
super().__init__()
|
58 |
+
self.in_channels = in_channels
|
59 |
+
out_channels = in_channels if out_channels is None else out_channels
|
60 |
+
self.out_channels = out_channels
|
61 |
+
|
62 |
+
self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
63 |
+
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
64 |
+
self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
|
65 |
+
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
66 |
+
if self.in_channels != self.out_channels:
|
67 |
+
self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
68 |
+
|
69 |
+
def forward(self, x):
|
70 |
+
h = x
|
71 |
+
h = self.norm1(h)
|
72 |
+
h = swish(h)
|
73 |
+
h = self.conv1(h)
|
74 |
+
|
75 |
+
h = self.norm2(h)
|
76 |
+
h = swish(h)
|
77 |
+
h = self.conv2(h)
|
78 |
+
|
79 |
+
if self.in_channels != self.out_channels:
|
80 |
+
x = self.nin_shortcut(x)
|
81 |
+
|
82 |
+
return x + h
|
83 |
+
|
84 |
+
|
85 |
+
class Downsample(nn.Module):
|
86 |
+
def __init__(self, in_channels: int):
|
87 |
+
super().__init__()
|
88 |
+
# no asymmetric padding in torch conv, must do it ourselves
|
89 |
+
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
90 |
+
|
91 |
+
def forward(self, x: Tensor):
|
92 |
+
pad = (0, 1, 0, 1)
|
93 |
+
x = nn.functional.pad(x, pad, mode="constant", value=0)
|
94 |
+
x = self.conv(x)
|
95 |
+
return x
|
96 |
+
|
97 |
+
|
98 |
+
class Upsample(nn.Module):
|
99 |
+
def __init__(self, in_channels: int):
|
100 |
+
super().__init__()
|
101 |
+
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
102 |
+
|
103 |
+
def forward(self, x: Tensor):
|
104 |
+
x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
105 |
+
x = self.conv(x)
|
106 |
+
return x
|
107 |
+
|
108 |
+
|
109 |
+
class Encoder(nn.Module):
|
110 |
+
def __init__(
|
111 |
+
self,
|
112 |
+
resolution: int,
|
113 |
+
in_channels: int,
|
114 |
+
ch: int,
|
115 |
+
ch_mult: list[int],
|
116 |
+
num_res_blocks: int,
|
117 |
+
z_channels: int,
|
118 |
+
):
|
119 |
+
super().__init__()
|
120 |
+
self.ch = ch
|
121 |
+
self.num_resolutions = len(ch_mult)
|
122 |
+
self.num_res_blocks = num_res_blocks
|
123 |
+
self.resolution = resolution
|
124 |
+
self.in_channels = in_channels
|
125 |
+
# downsampling
|
126 |
+
self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
|
127 |
+
|
128 |
+
curr_res = resolution
|
129 |
+
in_ch_mult = (1,) + tuple(ch_mult)
|
130 |
+
self.in_ch_mult = in_ch_mult
|
131 |
+
self.down = nn.ModuleList()
|
132 |
+
block_in = self.ch
|
133 |
+
for i_level in range(self.num_resolutions):
|
134 |
+
block = nn.ModuleList()
|
135 |
+
attn = nn.ModuleList()
|
136 |
+
block_in = ch * in_ch_mult[i_level]
|
137 |
+
block_out = ch * ch_mult[i_level]
|
138 |
+
for _ in range(self.num_res_blocks):
|
139 |
+
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
|
140 |
+
block_in = block_out
|
141 |
+
down = nn.Module()
|
142 |
+
down.block = block
|
143 |
+
down.attn = attn
|
144 |
+
if i_level != self.num_resolutions - 1:
|
145 |
+
down.downsample = Downsample(block_in)
|
146 |
+
curr_res = curr_res // 2
|
147 |
+
self.down.append(down)
|
148 |
+
|
149 |
+
# middle
|
150 |
+
self.mid = nn.Module()
|
151 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
152 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
153 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
154 |
+
|
155 |
+
# end
|
156 |
+
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
|
157 |
+
self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
|
158 |
+
|
159 |
+
def forward(self, x: Tensor) -> Tensor:
|
160 |
+
# downsampling
|
161 |
+
hs = [self.conv_in(x)]
|
162 |
+
for i_level in range(self.num_resolutions):
|
163 |
+
for i_block in range(self.num_res_blocks):
|
164 |
+
h = self.down[i_level].block[i_block](hs[-1])
|
165 |
+
if len(self.down[i_level].attn) > 0:
|
166 |
+
h = self.down[i_level].attn[i_block](h)
|
167 |
+
hs.append(h)
|
168 |
+
if i_level != self.num_resolutions - 1:
|
169 |
+
hs.append(self.down[i_level].downsample(hs[-1]))
|
170 |
+
|
171 |
+
# middle
|
172 |
+
h = hs[-1]
|
173 |
+
h = self.mid.block_1(h)
|
174 |
+
h = self.mid.attn_1(h)
|
175 |
+
h = self.mid.block_2(h)
|
176 |
+
# end
|
177 |
+
h = self.norm_out(h)
|
178 |
+
h = swish(h)
|
179 |
+
h = self.conv_out(h)
|
180 |
+
return h
|
181 |
+
|
182 |
+
|
183 |
+
class Decoder(nn.Module):
|
184 |
+
def __init__(
|
185 |
+
self,
|
186 |
+
ch: int,
|
187 |
+
out_ch: int,
|
188 |
+
ch_mult: list[int],
|
189 |
+
num_res_blocks: int,
|
190 |
+
in_channels: int,
|
191 |
+
resolution: int,
|
192 |
+
z_channels: int,
|
193 |
+
):
|
194 |
+
super().__init__()
|
195 |
+
self.ch = ch
|
196 |
+
self.num_resolutions = len(ch_mult)
|
197 |
+
self.num_res_blocks = num_res_blocks
|
198 |
+
self.resolution = resolution
|
199 |
+
self.in_channels = in_channels
|
200 |
+
self.ffactor = 2 ** (self.num_resolutions - 1)
|
201 |
+
|
202 |
+
# compute in_ch_mult, block_in and curr_res at lowest res
|
203 |
+
block_in = ch * ch_mult[self.num_resolutions - 1]
|
204 |
+
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
205 |
+
self.z_shape = (1, z_channels, curr_res, curr_res)
|
206 |
+
|
207 |
+
# z to block_in
|
208 |
+
self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
|
209 |
+
|
210 |
+
# middle
|
211 |
+
self.mid = nn.Module()
|
212 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
213 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
214 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
215 |
+
|
216 |
+
# upsampling
|
217 |
+
self.up = nn.ModuleList()
|
218 |
+
for i_level in reversed(range(self.num_resolutions)):
|
219 |
+
block = nn.ModuleList()
|
220 |
+
attn = nn.ModuleList()
|
221 |
+
block_out = ch * ch_mult[i_level]
|
222 |
+
for _ in range(self.num_res_blocks + 1):
|
223 |
+
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
|
224 |
+
block_in = block_out
|
225 |
+
up = nn.Module()
|
226 |
+
up.block = block
|
227 |
+
up.attn = attn
|
228 |
+
if i_level != 0:
|
229 |
+
up.upsample = Upsample(block_in)
|
230 |
+
curr_res = curr_res * 2
|
231 |
+
self.up.insert(0, up) # prepend to get consistent order
|
232 |
+
|
233 |
+
# end
|
234 |
+
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
|
235 |
+
self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
|
236 |
+
|
237 |
+
def forward(self, z: Tensor) -> Tensor:
|
238 |
+
# z to block_in
|
239 |
+
h = self.conv_in(z)
|
240 |
+
|
241 |
+
# middle
|
242 |
+
h = self.mid.block_1(h)
|
243 |
+
h = self.mid.attn_1(h)
|
244 |
+
h = self.mid.block_2(h)
|
245 |
+
|
246 |
+
# upsampling
|
247 |
+
for i_level in reversed(range(self.num_resolutions)):
|
248 |
+
for i_block in range(self.num_res_blocks + 1):
|
249 |
+
h = self.up[i_level].block[i_block](h)
|
250 |
+
if len(self.up[i_level].attn) > 0:
|
251 |
+
h = self.up[i_level].attn[i_block](h)
|
252 |
+
if i_level != 0:
|
253 |
+
h = self.up[i_level].upsample(h)
|
254 |
+
|
255 |
+
# end
|
256 |
+
h = self.norm_out(h)
|
257 |
+
h = swish(h)
|
258 |
+
h = self.conv_out(h)
|
259 |
+
return h
|
260 |
+
|
261 |
+
|
262 |
+
class DiagonalGaussian(nn.Module):
|
263 |
+
def __init__(self, sample: bool = True, chunk_dim: int = 1):
|
264 |
+
super().__init__()
|
265 |
+
self.sample = sample
|
266 |
+
self.chunk_dim = chunk_dim
|
267 |
+
|
268 |
+
def forward(self, z: Tensor) -> Tensor:
|
269 |
+
mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
|
270 |
+
if self.sample:
|
271 |
+
std = torch.exp(0.5 * logvar)
|
272 |
+
return mean + std * torch.randn_like(mean)
|
273 |
+
else:
|
274 |
+
return mean
|
275 |
+
|
276 |
+
|
277 |
+
class AutoEncoder(nn.Module):
|
278 |
+
def __init__(self, params: AutoEncoderParams):
|
279 |
+
super().__init__()
|
280 |
+
self.encoder = Encoder(
|
281 |
+
resolution=params.resolution,
|
282 |
+
in_channels=params.in_channels,
|
283 |
+
ch=params.ch,
|
284 |
+
ch_mult=params.ch_mult,
|
285 |
+
num_res_blocks=params.num_res_blocks,
|
286 |
+
z_channels=params.z_channels,
|
287 |
+
)
|
288 |
+
self.decoder = Decoder(
|
289 |
+
resolution=params.resolution,
|
290 |
+
in_channels=params.in_channels,
|
291 |
+
ch=params.ch,
|
292 |
+
out_ch=params.out_ch,
|
293 |
+
ch_mult=params.ch_mult,
|
294 |
+
num_res_blocks=params.num_res_blocks,
|
295 |
+
z_channels=params.z_channels,
|
296 |
+
)
|
297 |
+
self.reg = DiagonalGaussian()
|
298 |
+
|
299 |
+
self.scale_factor = params.scale_factor
|
300 |
+
self.shift_factor = params.shift_factor
|
301 |
+
|
302 |
+
def encode(self, x: Tensor) -> Tensor:
|
303 |
+
z = self.reg(self.encoder(x))
|
304 |
+
z = self.scale_factor * (z - self.shift_factor)
|
305 |
+
return z
|
306 |
+
|
307 |
+
def decode(self, z: Tensor) -> Tensor:
|
308 |
+
z = z / self.scale_factor + self.shift_factor
|
309 |
+
return self.decoder(z)
|
310 |
+
|
311 |
+
def forward(self, x: Tensor) -> Tensor:
|
312 |
+
return self.decode(self.encode(x))
|
flux/modules/conditioner.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch import Tensor, nn
|
2 |
+
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
|
3 |
+
|
4 |
+
|
5 |
+
class HFEmbedder(nn.Module):
|
6 |
+
def __init__(self, version: str, max_length: int, **hf_kwargs):
|
7 |
+
super().__init__()
|
8 |
+
self.is_clip = version.startswith("openai")
|
9 |
+
self.max_length = max_length
|
10 |
+
self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
|
11 |
+
|
12 |
+
if self.is_clip:
|
13 |
+
self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length)
|
14 |
+
self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs)
|
15 |
+
else:
|
16 |
+
self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length)
|
17 |
+
self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs)
|
18 |
+
|
19 |
+
self.hf_module = self.hf_module.eval().requires_grad_(False)
|
20 |
+
|
21 |
+
def forward(self, text: list[str]) -> Tensor:
|
22 |
+
batch_encoding = self.tokenizer(
|
23 |
+
text,
|
24 |
+
truncation=True,
|
25 |
+
max_length=self.max_length,
|
26 |
+
return_length=False,
|
27 |
+
return_overflowing_tokens=False,
|
28 |
+
padding="max_length",
|
29 |
+
return_tensors="pt",
|
30 |
+
)
|
31 |
+
|
32 |
+
outputs = self.hf_module(
|
33 |
+
input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
|
34 |
+
attention_mask=None,
|
35 |
+
output_hidden_states=False,
|
36 |
+
)
|
37 |
+
return outputs[self.output_key]
|
flux/modules/layers.py
ADDED
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from dataclasses import dataclass
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from einops import rearrange
|
6 |
+
from torch import Tensor, nn
|
7 |
+
|
8 |
+
from flux.math import attention, rope
|
9 |
+
|
10 |
+
|
11 |
+
class EmbedND(nn.Module):
|
12 |
+
def __init__(self, dim: int, theta: int, axes_dim: list[int]):
|
13 |
+
super().__init__()
|
14 |
+
self.dim = dim
|
15 |
+
self.theta = theta
|
16 |
+
self.axes_dim = axes_dim
|
17 |
+
|
18 |
+
def forward(self, ids: Tensor) -> Tensor:
|
19 |
+
n_axes = ids.shape[-1]
|
20 |
+
emb = torch.cat(
|
21 |
+
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
|
22 |
+
dim=-3,
|
23 |
+
)
|
24 |
+
|
25 |
+
return emb.unsqueeze(1)
|
26 |
+
|
27 |
+
|
28 |
+
def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
|
29 |
+
"""
|
30 |
+
Create sinusoidal timestep embeddings.
|
31 |
+
:param t: a 1-D Tensor of N indices, one per batch element.
|
32 |
+
These may be fractional.
|
33 |
+
:param dim: the dimension of the output.
|
34 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
35 |
+
:return: an (N, D) Tensor of positional embeddings.
|
36 |
+
"""
|
37 |
+
t = time_factor * t
|
38 |
+
half = dim // 2
|
39 |
+
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
|
40 |
+
t.device
|
41 |
+
)
|
42 |
+
|
43 |
+
args = t[:, None].float() * freqs[None]
|
44 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
45 |
+
if dim % 2:
|
46 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
47 |
+
if torch.is_floating_point(t):
|
48 |
+
embedding = embedding.to(t)
|
49 |
+
return embedding
|
50 |
+
|
51 |
+
|
52 |
+
class MLPEmbedder(nn.Module):
|
53 |
+
def __init__(self, in_dim: int, hidden_dim: int):
|
54 |
+
super().__init__()
|
55 |
+
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
|
56 |
+
self.silu = nn.SiLU()
|
57 |
+
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
|
58 |
+
|
59 |
+
def forward(self, x: Tensor) -> Tensor:
|
60 |
+
return self.out_layer(self.silu(self.in_layer(x)))
|
61 |
+
|
62 |
+
|
63 |
+
class RMSNorm(torch.nn.Module):
|
64 |
+
def __init__(self, dim: int):
|
65 |
+
super().__init__()
|
66 |
+
self.scale = nn.Parameter(torch.ones(dim))
|
67 |
+
|
68 |
+
def forward(self, x: Tensor):
|
69 |
+
x_dtype = x.dtype
|
70 |
+
x = x.float()
|
71 |
+
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
|
72 |
+
return (x * rrms).to(dtype=x_dtype) * self.scale
|
73 |
+
|
74 |
+
|
75 |
+
class QKNorm(torch.nn.Module):
|
76 |
+
def __init__(self, dim: int):
|
77 |
+
super().__init__()
|
78 |
+
self.query_norm = RMSNorm(dim)
|
79 |
+
self.key_norm = RMSNorm(dim)
|
80 |
+
|
81 |
+
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
|
82 |
+
q = self.query_norm(q)
|
83 |
+
k = self.key_norm(k)
|
84 |
+
return q.to(v), k.to(v)
|
85 |
+
|
86 |
+
|
87 |
+
class SelfAttention(nn.Module):
|
88 |
+
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
|
89 |
+
super().__init__()
|
90 |
+
self.num_heads = num_heads
|
91 |
+
head_dim = dim // num_heads
|
92 |
+
|
93 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
94 |
+
self.norm = QKNorm(head_dim)
|
95 |
+
self.proj = nn.Linear(dim, dim)
|
96 |
+
|
97 |
+
def forward(self, x: Tensor, pe: Tensor) -> Tensor:
|
98 |
+
qkv = self.qkv(x)
|
99 |
+
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
100 |
+
q, k = self.norm(q, k, v)
|
101 |
+
x = attention(q, k, v, pe=pe)
|
102 |
+
x = self.proj(x)
|
103 |
+
return x
|
104 |
+
|
105 |
+
|
106 |
+
@dataclass
|
107 |
+
class ModulationOut:
|
108 |
+
shift: Tensor
|
109 |
+
scale: Tensor
|
110 |
+
gate: Tensor
|
111 |
+
|
112 |
+
|
113 |
+
class Modulation(nn.Module):
|
114 |
+
def __init__(self, dim: int, double: bool):
|
115 |
+
super().__init__()
|
116 |
+
self.is_double = double
|
117 |
+
self.multiplier = 6 if double else 3
|
118 |
+
self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
|
119 |
+
|
120 |
+
def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut]:
|
121 |
+
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
|
122 |
+
|
123 |
+
return (
|
124 |
+
ModulationOut(*out[:3]),
|
125 |
+
ModulationOut(*out[3:]) if self.is_double else None,
|
126 |
+
)
|
127 |
+
|
128 |
+
|
129 |
+
class DoubleStreamBlock(nn.Module):
|
130 |
+
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
|
131 |
+
super().__init__()
|
132 |
+
|
133 |
+
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
134 |
+
self.num_heads = num_heads
|
135 |
+
self.hidden_size = hidden_size
|
136 |
+
self.img_mod = Modulation(hidden_size, double=True)
|
137 |
+
self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
138 |
+
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
|
139 |
+
|
140 |
+
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
141 |
+
self.img_mlp = nn.Sequential(
|
142 |
+
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
|
143 |
+
nn.GELU(approximate="tanh"),
|
144 |
+
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
|
145 |
+
)
|
146 |
+
|
147 |
+
self.txt_mod = Modulation(hidden_size, double=True)
|
148 |
+
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
149 |
+
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
|
150 |
+
|
151 |
+
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
152 |
+
self.txt_mlp = nn.Sequential(
|
153 |
+
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
|
154 |
+
nn.GELU(approximate="tanh"),
|
155 |
+
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
|
156 |
+
)
|
157 |
+
|
158 |
+
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
|
159 |
+
img_mod1, img_mod2 = self.img_mod(vec)
|
160 |
+
txt_mod1, txt_mod2 = self.txt_mod(vec)
|
161 |
+
|
162 |
+
# prepare image for attention
|
163 |
+
img_modulated = self.img_norm1(img)
|
164 |
+
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
|
165 |
+
img_qkv = self.img_attn.qkv(img_modulated)
|
166 |
+
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
167 |
+
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
|
168 |
+
|
169 |
+
# prepare txt for attention
|
170 |
+
txt_modulated = self.txt_norm1(txt)
|
171 |
+
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
|
172 |
+
txt_qkv = self.txt_attn.qkv(txt_modulated)
|
173 |
+
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
174 |
+
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
175 |
+
|
176 |
+
# run actual attention
|
177 |
+
q = torch.cat((txt_q, img_q), dim=2)
|
178 |
+
k = torch.cat((txt_k, img_k), dim=2)
|
179 |
+
v = torch.cat((txt_v, img_v), dim=2)
|
180 |
+
|
181 |
+
attn = attention(q, k, v, pe=pe)
|
182 |
+
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
|
183 |
+
|
184 |
+
# calculate the img bloks
|
185 |
+
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
|
186 |
+
img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
|
187 |
+
|
188 |
+
# calculate the txt bloks
|
189 |
+
txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
|
190 |
+
txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
|
191 |
+
return img, txt
|
192 |
+
|
193 |
+
|
194 |
+
class SingleStreamBlock(nn.Module):
|
195 |
+
"""
|
196 |
+
A DiT block with parallel linear layers as described in
|
197 |
+
https://arxiv.org/abs/2302.05442 and adapted modulation interface.
|
198 |
+
"""
|
199 |
+
|
200 |
+
def __init__(
|
201 |
+
self,
|
202 |
+
hidden_size: int,
|
203 |
+
num_heads: int,
|
204 |
+
mlp_ratio: float = 4.0,
|
205 |
+
qk_scale: float = None,
|
206 |
+
):
|
207 |
+
super().__init__()
|
208 |
+
self.hidden_dim = hidden_size
|
209 |
+
self.num_heads = num_heads
|
210 |
+
head_dim = hidden_size // num_heads
|
211 |
+
self.scale = qk_scale or head_dim**-0.5
|
212 |
+
|
213 |
+
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
214 |
+
# qkv and mlp_in
|
215 |
+
self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
|
216 |
+
# proj and mlp_out
|
217 |
+
self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
|
218 |
+
|
219 |
+
self.norm = QKNorm(head_dim)
|
220 |
+
|
221 |
+
self.hidden_size = hidden_size
|
222 |
+
self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
223 |
+
|
224 |
+
self.mlp_act = nn.GELU(approximate="tanh")
|
225 |
+
self.modulation = Modulation(hidden_size, double=False)
|
226 |
+
|
227 |
+
def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
|
228 |
+
mod, _ = self.modulation(vec)
|
229 |
+
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
|
230 |
+
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
231 |
+
|
232 |
+
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
233 |
+
q, k = self.norm(q, k, v)
|
234 |
+
|
235 |
+
# compute attention
|
236 |
+
attn = attention(q, k, v, pe=pe)
|
237 |
+
# compute activation in mlp stream, cat again and run second linear layer
|
238 |
+
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
239 |
+
return x + mod.gate * output
|
240 |
+
|
241 |
+
|
242 |
+
class LastLayer(nn.Module):
|
243 |
+
def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
|
244 |
+
super().__init__()
|
245 |
+
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
246 |
+
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
|
247 |
+
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
|
248 |
+
|
249 |
+
def forward(self, x: Tensor, vec: Tensor) -> Tensor:
|
250 |
+
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
|
251 |
+
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
|
252 |
+
x = self.linear(x)
|
253 |
+
return x
|
flux/sampling.py
ADDED
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from typing import Callable
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from einops import rearrange, repeat
|
6 |
+
from torch import Tensor
|
7 |
+
|
8 |
+
from .model import Flux
|
9 |
+
from .modules.conditioner import HFEmbedder
|
10 |
+
|
11 |
+
|
12 |
+
def get_noise(
|
13 |
+
num_samples: int,
|
14 |
+
height: int,
|
15 |
+
width: int,
|
16 |
+
device: torch.device,
|
17 |
+
dtype: torch.dtype,
|
18 |
+
seed: int,
|
19 |
+
):
|
20 |
+
return torch.randn(
|
21 |
+
num_samples,
|
22 |
+
16,
|
23 |
+
# allow for packing
|
24 |
+
2 * math.ceil(height / 16),
|
25 |
+
2 * math.ceil(width / 16),
|
26 |
+
device=device,
|
27 |
+
dtype=dtype,
|
28 |
+
generator=torch.Generator(device=device).manual_seed(seed),
|
29 |
+
)
|
30 |
+
|
31 |
+
|
32 |
+
def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str) -> dict[str, Tensor]:
|
33 |
+
bs, c, h, w = img.shape
|
34 |
+
if bs == 1 and not isinstance(prompt, str):
|
35 |
+
bs = len(prompt)
|
36 |
+
|
37 |
+
img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
38 |
+
if img.shape[0] == 1 and bs > 1:
|
39 |
+
img = repeat(img, "1 ... -> bs ...", bs=bs)
|
40 |
+
|
41 |
+
img_ids = torch.zeros(h // 2, w // 2, 3)
|
42 |
+
img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
|
43 |
+
img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
|
44 |
+
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
45 |
+
|
46 |
+
if isinstance(prompt, str):
|
47 |
+
prompt = [prompt]
|
48 |
+
txt = t5(prompt)
|
49 |
+
if txt.shape[0] == 1 and bs > 1:
|
50 |
+
txt = repeat(txt, "1 ... -> bs ...", bs=bs)
|
51 |
+
txt_ids = torch.zeros(bs, txt.shape[1], 3)
|
52 |
+
|
53 |
+
vec = clip(prompt)
|
54 |
+
if vec.shape[0] == 1 and bs > 1:
|
55 |
+
vec = repeat(vec, "1 ... -> bs ...", bs=bs)
|
56 |
+
|
57 |
+
return {
|
58 |
+
"img": img,
|
59 |
+
"img_ids": img_ids.to(img.device),
|
60 |
+
"txt": txt.to(img.device),
|
61 |
+
"txt_ids": txt_ids.to(img.device),
|
62 |
+
"vec": vec.to(img.device),
|
63 |
+
}
|
64 |
+
|
65 |
+
|
66 |
+
def time_shift(mu: float, sigma: float, t: Tensor):
|
67 |
+
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
|
68 |
+
|
69 |
+
|
70 |
+
def get_lin_function(
|
71 |
+
x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
|
72 |
+
) -> Callable[[float], float]:
|
73 |
+
m = (y2 - y1) / (x2 - x1)
|
74 |
+
b = y1 - m * x1
|
75 |
+
return lambda x: m * x + b
|
76 |
+
|
77 |
+
|
78 |
+
def get_schedule(
|
79 |
+
num_steps: int,
|
80 |
+
image_seq_len: int,
|
81 |
+
base_shift: float = 0.5,
|
82 |
+
max_shift: float = 1.15,
|
83 |
+
shift: bool = True,
|
84 |
+
) -> list[float]:
|
85 |
+
# extra step for zero
|
86 |
+
timesteps = torch.linspace(1, 0, num_steps + 1)
|
87 |
+
|
88 |
+
# shifting the schedule to favor high timesteps for higher signal images
|
89 |
+
if shift:
|
90 |
+
# eastimate mu based on linear estimation between two points
|
91 |
+
mu = get_lin_function(y1=base_shift, y2=max_shift)(image_seq_len)
|
92 |
+
timesteps = time_shift(mu, 1.0, timesteps)
|
93 |
+
|
94 |
+
return timesteps.tolist()
|
95 |
+
|
96 |
+
|
97 |
+
def denoise(
|
98 |
+
model: Flux,
|
99 |
+
# model input
|
100 |
+
img: Tensor,
|
101 |
+
img_ids: Tensor,
|
102 |
+
txt: Tensor,
|
103 |
+
txt_ids: Tensor,
|
104 |
+
vec: Tensor,
|
105 |
+
timesteps: list[float],
|
106 |
+
guidance: float = 4.0,
|
107 |
+
id_weight=1.0,
|
108 |
+
id=None,
|
109 |
+
start_step=0,
|
110 |
+
uncond_id=None,
|
111 |
+
true_cfg=1.0,
|
112 |
+
timestep_to_start_cfg=1,
|
113 |
+
neg_txt=None,
|
114 |
+
neg_txt_ids=None,
|
115 |
+
neg_vec=None,
|
116 |
+
aggressive_offload=False,
|
117 |
+
):
|
118 |
+
# this is ignored for schnell
|
119 |
+
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
|
120 |
+
use_true_cfg = abs(true_cfg - 1.0) > 1e-2
|
121 |
+
for i, (t_curr, t_prev) in enumerate(zip(timesteps[:-1], timesteps[1:])):
|
122 |
+
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
|
123 |
+
pred = model(
|
124 |
+
img=img,
|
125 |
+
img_ids=img_ids,
|
126 |
+
txt=txt,
|
127 |
+
txt_ids=txt_ids,
|
128 |
+
y=vec,
|
129 |
+
timesteps=t_vec,
|
130 |
+
guidance=guidance_vec,
|
131 |
+
id=id if i >= start_step else None,
|
132 |
+
id_weight=id_weight,
|
133 |
+
aggressive_offload=aggressive_offload,
|
134 |
+
)
|
135 |
+
|
136 |
+
if use_true_cfg and i >= timestep_to_start_cfg:
|
137 |
+
neg_pred = model(
|
138 |
+
img=img,
|
139 |
+
img_ids=img_ids,
|
140 |
+
txt=neg_txt,
|
141 |
+
txt_ids=neg_txt_ids,
|
142 |
+
y=neg_vec,
|
143 |
+
timesteps=t_vec,
|
144 |
+
guidance=guidance_vec,
|
145 |
+
id=uncond_id if i >= start_step else None,
|
146 |
+
id_weight=id_weight,
|
147 |
+
aggressive_offload=aggressive_offload,
|
148 |
+
)
|
149 |
+
pred = neg_pred + true_cfg * (pred - neg_pred)
|
150 |
+
|
151 |
+
img = img + (t_prev - t_curr) * pred
|
152 |
+
|
153 |
+
return img
|
154 |
+
|
155 |
+
|
156 |
+
def unpack(x: Tensor, height: int, width: int) -> Tensor:
|
157 |
+
return rearrange(
|
158 |
+
x,
|
159 |
+
"b (h w) (c ph pw) -> b c (h ph) (w pw)",
|
160 |
+
h=math.ceil(height / 16),
|
161 |
+
w=math.ceil(width / 16),
|
162 |
+
ph=2,
|
163 |
+
pw=2,
|
164 |
+
)
|
flux/util.py
ADDED
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
from dataclasses import dataclass
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from huggingface_hub import hf_hub_download
|
7 |
+
from safetensors.torch import load_file as load_sft
|
8 |
+
|
9 |
+
from flux.model import Flux, FluxParams
|
10 |
+
from flux.modules.autoencoder import AutoEncoder, AutoEncoderParams
|
11 |
+
from flux.modules.conditioner import HFEmbedder
|
12 |
+
|
13 |
+
|
14 |
+
@dataclass
|
15 |
+
class SamplingOptions:
|
16 |
+
prompt: str
|
17 |
+
width: int
|
18 |
+
height: int
|
19 |
+
num_steps: int
|
20 |
+
guidance: float
|
21 |
+
seed: int
|
22 |
+
|
23 |
+
|
24 |
+
@dataclass
|
25 |
+
class ModelSpec:
|
26 |
+
params: FluxParams
|
27 |
+
ae_params: AutoEncoderParams
|
28 |
+
ckpt_path: str
|
29 |
+
ae_path: str
|
30 |
+
repo_id: str
|
31 |
+
repo_flow: str
|
32 |
+
repo_ae: str
|
33 |
+
|
34 |
+
|
35 |
+
configs = {
|
36 |
+
"flux-dev": ModelSpec(
|
37 |
+
repo_id="black-forest-labs/FLUX.1-dev",
|
38 |
+
repo_flow="flux1-dev.safetensors",
|
39 |
+
repo_ae="ae.safetensors",
|
40 |
+
ckpt_path='models/flux1-dev.safetensors',
|
41 |
+
params=FluxParams(
|
42 |
+
in_channels=64,
|
43 |
+
vec_in_dim=768,
|
44 |
+
context_in_dim=4096,
|
45 |
+
hidden_size=3072,
|
46 |
+
mlp_ratio=4.0,
|
47 |
+
num_heads=24,
|
48 |
+
depth=19,
|
49 |
+
depth_single_blocks=38,
|
50 |
+
axes_dim=[16, 56, 56],
|
51 |
+
theta=10_000,
|
52 |
+
qkv_bias=True,
|
53 |
+
guidance_embed=True,
|
54 |
+
),
|
55 |
+
ae_path='models/ae.safetensors',
|
56 |
+
ae_params=AutoEncoderParams(
|
57 |
+
resolution=256,
|
58 |
+
in_channels=3,
|
59 |
+
ch=128,
|
60 |
+
out_ch=3,
|
61 |
+
ch_mult=[1, 2, 4, 4],
|
62 |
+
num_res_blocks=2,
|
63 |
+
z_channels=16,
|
64 |
+
scale_factor=0.3611,
|
65 |
+
shift_factor=0.1159,
|
66 |
+
),
|
67 |
+
),
|
68 |
+
"flux-schnell": ModelSpec(
|
69 |
+
repo_id="black-forest-labs/FLUX.1-schnell",
|
70 |
+
repo_flow="flux1-schnell.safetensors",
|
71 |
+
repo_ae="ae.safetensors",
|
72 |
+
ckpt_path=os.getenv("FLUX_SCHNELL"),
|
73 |
+
params=FluxParams(
|
74 |
+
in_channels=64,
|
75 |
+
vec_in_dim=768,
|
76 |
+
context_in_dim=4096,
|
77 |
+
hidden_size=3072,
|
78 |
+
mlp_ratio=4.0,
|
79 |
+
num_heads=24,
|
80 |
+
depth=19,
|
81 |
+
depth_single_blocks=38,
|
82 |
+
axes_dim=[16, 56, 56],
|
83 |
+
theta=10_000,
|
84 |
+
qkv_bias=True,
|
85 |
+
guidance_embed=False,
|
86 |
+
),
|
87 |
+
ae_path=os.getenv("AE"),
|
88 |
+
ae_params=AutoEncoderParams(
|
89 |
+
resolution=256,
|
90 |
+
in_channels=3,
|
91 |
+
ch=128,
|
92 |
+
out_ch=3,
|
93 |
+
ch_mult=[1, 2, 4, 4],
|
94 |
+
num_res_blocks=2,
|
95 |
+
z_channels=16,
|
96 |
+
scale_factor=0.3611,
|
97 |
+
shift_factor=0.1159,
|
98 |
+
),
|
99 |
+
),
|
100 |
+
}
|
101 |
+
|
102 |
+
|
103 |
+
def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
|
104 |
+
if len(missing) > 0 and len(unexpected) > 0:
|
105 |
+
print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
|
106 |
+
print("\n" + "-" * 79 + "\n")
|
107 |
+
print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
|
108 |
+
elif len(missing) > 0:
|
109 |
+
print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
|
110 |
+
elif len(unexpected) > 0:
|
111 |
+
print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))
|
112 |
+
|
113 |
+
|
114 |
+
def load_flow_model(name: str, device: str = "cuda", hf_download: bool = True):
|
115 |
+
# Loading Flux
|
116 |
+
print("Init model")
|
117 |
+
ckpt_path = configs[name].ckpt_path
|
118 |
+
if (
|
119 |
+
not os.path.exists(ckpt_path)
|
120 |
+
and configs[name].repo_id is not None
|
121 |
+
and configs[name].repo_flow is not None
|
122 |
+
and hf_download
|
123 |
+
):
|
124 |
+
ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow, local_dir='models')
|
125 |
+
|
126 |
+
with torch.device(device):
|
127 |
+
model = Flux(configs[name].params).to(torch.bfloat16)
|
128 |
+
|
129 |
+
if ckpt_path is not None:
|
130 |
+
print("Loading checkpoint")
|
131 |
+
# load_sft doesn't support torch.device
|
132 |
+
sd = load_sft(ckpt_path, device=str(device))
|
133 |
+
missing, unexpected = model.load_state_dict(sd, strict=False)
|
134 |
+
print_load_warning(missing, unexpected)
|
135 |
+
return model
|
136 |
+
|
137 |
+
# from XLabs-AI https://github.com/XLabs-AI/x-flux/blob/1f8ef54972105ad9062be69fe6b7f841bce02a08/src/flux/util.py#L330
|
138 |
+
def load_flow_model_quintized(name: str, device: str = "cuda", hf_download: bool = True):
|
139 |
+
# Loading Flux
|
140 |
+
print("Init model")
|
141 |
+
ckpt_path = 'models/flux-dev-fp8.safetensors'
|
142 |
+
if (
|
143 |
+
not os.path.exists(ckpt_path)
|
144 |
+
and hf_download
|
145 |
+
):
|
146 |
+
ckpt_path = hf_hub_download("XLabs-AI/flux-dev-fp8", "flux-dev-fp8.safetensors")
|
147 |
+
json_path = hf_hub_download("XLabs-AI/flux-dev-fp8", 'flux_dev_quantization_map.json')
|
148 |
+
|
149 |
+
model = Flux(configs[name].params).to(torch.bfloat16)
|
150 |
+
|
151 |
+
print("Loading checkpoint")
|
152 |
+
# load_sft doesn't support torch.device
|
153 |
+
sd = load_sft(ckpt_path, device='cpu')
|
154 |
+
with open(json_path) as f:
|
155 |
+
quantization_map = json.load(f)
|
156 |
+
print("Start a quantization process...")
|
157 |
+
from optimum.quanto import requantize
|
158 |
+
requantize(model, sd, quantization_map, device=device)
|
159 |
+
print("Model is quantized!")
|
160 |
+
return model
|
161 |
+
|
162 |
+
|
163 |
+
def load_t5(device: str = "cuda", max_length: int = 512) -> HFEmbedder:
|
164 |
+
# max length 64, 128, 256 and 512 should work (if your sequence is short enough)
|
165 |
+
return HFEmbedder("xlabs-ai/xflux_text_encoders", max_length=max_length, torch_dtype=torch.bfloat16).to(device)
|
166 |
+
|
167 |
+
|
168 |
+
def load_clip(device: str = "cuda") -> HFEmbedder:
|
169 |
+
return HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16).to(device)
|
170 |
+
|
171 |
+
|
172 |
+
def load_ae(name: str, device: str = "cuda", hf_download: bool = True) -> AutoEncoder:
|
173 |
+
ckpt_path = configs[name].ae_path
|
174 |
+
if (
|
175 |
+
not os.path.exists(ckpt_path)
|
176 |
+
and configs[name].repo_id is not None
|
177 |
+
and configs[name].repo_ae is not None
|
178 |
+
and hf_download
|
179 |
+
):
|
180 |
+
ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_ae, local_dir='models')
|
181 |
+
|
182 |
+
# Loading the autoencoder
|
183 |
+
print("Init AE")
|
184 |
+
with torch.device(device):
|
185 |
+
ae = AutoEncoder(configs[name].ae_params)
|
186 |
+
|
187 |
+
if ckpt_path is not None:
|
188 |
+
sd = load_sft(ckpt_path, device=str(device))
|
189 |
+
missing, unexpected = ae.load_state_dict(sd, strict=False)
|
190 |
+
print_load_warning(missing, unexpected)
|
191 |
+
return ae
|
pulid/attention_processor.py
ADDED
@@ -0,0 +1,422 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
NUM_ZERO = 0
|
7 |
+
ORTHO = False
|
8 |
+
ORTHO_v2 = False
|
9 |
+
|
10 |
+
|
11 |
+
class AttnProcessor(nn.Module):
|
12 |
+
def __init__(self):
|
13 |
+
super().__init__()
|
14 |
+
|
15 |
+
def __call__(
|
16 |
+
self,
|
17 |
+
attn,
|
18 |
+
hidden_states,
|
19 |
+
encoder_hidden_states=None,
|
20 |
+
attention_mask=None,
|
21 |
+
temb=None,
|
22 |
+
id_embedding=None,
|
23 |
+
id_scale=1.0,
|
24 |
+
):
|
25 |
+
residual = hidden_states
|
26 |
+
|
27 |
+
if attn.spatial_norm is not None:
|
28 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
29 |
+
|
30 |
+
input_ndim = hidden_states.ndim
|
31 |
+
|
32 |
+
if input_ndim == 4:
|
33 |
+
batch_size, channel, height, width = hidden_states.shape
|
34 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
35 |
+
|
36 |
+
batch_size, sequence_length, _ = (
|
37 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
38 |
+
)
|
39 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
40 |
+
|
41 |
+
if attn.group_norm is not None:
|
42 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
43 |
+
|
44 |
+
query = attn.to_q(hidden_states)
|
45 |
+
|
46 |
+
if encoder_hidden_states is None:
|
47 |
+
encoder_hidden_states = hidden_states
|
48 |
+
elif attn.norm_cross:
|
49 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
50 |
+
|
51 |
+
key = attn.to_k(encoder_hidden_states)
|
52 |
+
value = attn.to_v(encoder_hidden_states)
|
53 |
+
|
54 |
+
query = attn.head_to_batch_dim(query)
|
55 |
+
key = attn.head_to_batch_dim(key)
|
56 |
+
value = attn.head_to_batch_dim(value)
|
57 |
+
|
58 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
59 |
+
hidden_states = torch.bmm(attention_probs, value)
|
60 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
61 |
+
|
62 |
+
# linear proj
|
63 |
+
hidden_states = attn.to_out[0](hidden_states)
|
64 |
+
# dropout
|
65 |
+
hidden_states = attn.to_out[1](hidden_states)
|
66 |
+
|
67 |
+
if input_ndim == 4:
|
68 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
69 |
+
|
70 |
+
if attn.residual_connection:
|
71 |
+
hidden_states = hidden_states + residual
|
72 |
+
|
73 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
74 |
+
|
75 |
+
return hidden_states
|
76 |
+
|
77 |
+
|
78 |
+
class IDAttnProcessor(nn.Module):
|
79 |
+
r"""
|
80 |
+
Attention processor for ID-Adapater.
|
81 |
+
Args:
|
82 |
+
hidden_size (`int`):
|
83 |
+
The hidden size of the attention layer.
|
84 |
+
cross_attention_dim (`int`):
|
85 |
+
The number of channels in the `encoder_hidden_states`.
|
86 |
+
scale (`float`, defaults to 1.0):
|
87 |
+
the weight scale of image prompt.
|
88 |
+
"""
|
89 |
+
|
90 |
+
def __init__(self, hidden_size, cross_attention_dim=None):
|
91 |
+
super().__init__()
|
92 |
+
self.id_to_k = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
93 |
+
self.id_to_v = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
94 |
+
|
95 |
+
def __call__(
|
96 |
+
self,
|
97 |
+
attn,
|
98 |
+
hidden_states,
|
99 |
+
encoder_hidden_states=None,
|
100 |
+
attention_mask=None,
|
101 |
+
temb=None,
|
102 |
+
id_embedding=None,
|
103 |
+
id_scale=1.0,
|
104 |
+
):
|
105 |
+
residual = hidden_states
|
106 |
+
|
107 |
+
if attn.spatial_norm is not None:
|
108 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
109 |
+
|
110 |
+
input_ndim = hidden_states.ndim
|
111 |
+
|
112 |
+
if input_ndim == 4:
|
113 |
+
batch_size, channel, height, width = hidden_states.shape
|
114 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
115 |
+
|
116 |
+
batch_size, sequence_length, _ = (
|
117 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
118 |
+
)
|
119 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
120 |
+
|
121 |
+
if attn.group_norm is not None:
|
122 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
123 |
+
|
124 |
+
query = attn.to_q(hidden_states)
|
125 |
+
|
126 |
+
if encoder_hidden_states is None:
|
127 |
+
encoder_hidden_states = hidden_states
|
128 |
+
elif attn.norm_cross:
|
129 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
130 |
+
|
131 |
+
key = attn.to_k(encoder_hidden_states)
|
132 |
+
value = attn.to_v(encoder_hidden_states)
|
133 |
+
|
134 |
+
query = attn.head_to_batch_dim(query)
|
135 |
+
key = attn.head_to_batch_dim(key)
|
136 |
+
value = attn.head_to_batch_dim(value)
|
137 |
+
|
138 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
139 |
+
hidden_states = torch.bmm(attention_probs, value)
|
140 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
141 |
+
|
142 |
+
# for id-adapter
|
143 |
+
if id_embedding is not None:
|
144 |
+
if NUM_ZERO == 0:
|
145 |
+
id_key = self.id_to_k(id_embedding)
|
146 |
+
id_value = self.id_to_v(id_embedding)
|
147 |
+
else:
|
148 |
+
zero_tensor = torch.zeros(
|
149 |
+
(id_embedding.size(0), NUM_ZERO, id_embedding.size(-1)),
|
150 |
+
dtype=id_embedding.dtype,
|
151 |
+
device=id_embedding.device,
|
152 |
+
)
|
153 |
+
id_key = self.id_to_k(torch.cat((id_embedding, zero_tensor), dim=1))
|
154 |
+
id_value = self.id_to_v(torch.cat((id_embedding, zero_tensor), dim=1))
|
155 |
+
|
156 |
+
id_key = attn.head_to_batch_dim(id_key).to(query.dtype)
|
157 |
+
id_value = attn.head_to_batch_dim(id_value).to(query.dtype)
|
158 |
+
|
159 |
+
id_attention_probs = attn.get_attention_scores(query, id_key, None)
|
160 |
+
id_hidden_states = torch.bmm(id_attention_probs, id_value)
|
161 |
+
id_hidden_states = attn.batch_to_head_dim(id_hidden_states)
|
162 |
+
|
163 |
+
if not ORTHO:
|
164 |
+
hidden_states = hidden_states + id_scale * id_hidden_states
|
165 |
+
else:
|
166 |
+
projection = (
|
167 |
+
torch.sum((hidden_states * id_hidden_states), dim=-2, keepdim=True)
|
168 |
+
/ torch.sum((hidden_states * hidden_states), dim=-2, keepdim=True)
|
169 |
+
* hidden_states
|
170 |
+
)
|
171 |
+
orthogonal = id_hidden_states - projection
|
172 |
+
hidden_states = hidden_states + id_scale * orthogonal
|
173 |
+
|
174 |
+
# linear proj
|
175 |
+
hidden_states = attn.to_out[0](hidden_states)
|
176 |
+
# dropout
|
177 |
+
hidden_states = attn.to_out[1](hidden_states)
|
178 |
+
|
179 |
+
if input_ndim == 4:
|
180 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
181 |
+
|
182 |
+
if attn.residual_connection:
|
183 |
+
hidden_states = hidden_states + residual
|
184 |
+
|
185 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
186 |
+
|
187 |
+
return hidden_states
|
188 |
+
|
189 |
+
|
190 |
+
class AttnProcessor2_0(nn.Module):
|
191 |
+
r"""
|
192 |
+
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
193 |
+
"""
|
194 |
+
|
195 |
+
def __init__(self):
|
196 |
+
super().__init__()
|
197 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
198 |
+
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
199 |
+
|
200 |
+
def __call__(
|
201 |
+
self,
|
202 |
+
attn,
|
203 |
+
hidden_states,
|
204 |
+
encoder_hidden_states=None,
|
205 |
+
attention_mask=None,
|
206 |
+
temb=None,
|
207 |
+
id_embedding=None,
|
208 |
+
id_scale=1.0,
|
209 |
+
):
|
210 |
+
residual = hidden_states
|
211 |
+
|
212 |
+
if attn.spatial_norm is not None:
|
213 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
214 |
+
|
215 |
+
input_ndim = hidden_states.ndim
|
216 |
+
|
217 |
+
if input_ndim == 4:
|
218 |
+
batch_size, channel, height, width = hidden_states.shape
|
219 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
220 |
+
|
221 |
+
batch_size, sequence_length, _ = (
|
222 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
223 |
+
)
|
224 |
+
|
225 |
+
if attention_mask is not None:
|
226 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
227 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
228 |
+
# (batch, heads, source_length, target_length)
|
229 |
+
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
230 |
+
|
231 |
+
if attn.group_norm is not None:
|
232 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
233 |
+
|
234 |
+
query = attn.to_q(hidden_states)
|
235 |
+
|
236 |
+
if encoder_hidden_states is None:
|
237 |
+
encoder_hidden_states = hidden_states
|
238 |
+
elif attn.norm_cross:
|
239 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
240 |
+
|
241 |
+
key = attn.to_k(encoder_hidden_states)
|
242 |
+
value = attn.to_v(encoder_hidden_states)
|
243 |
+
|
244 |
+
inner_dim = key.shape[-1]
|
245 |
+
head_dim = inner_dim // attn.heads
|
246 |
+
|
247 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
248 |
+
|
249 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
250 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
251 |
+
|
252 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
253 |
+
hidden_states = F.scaled_dot_product_attention(
|
254 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
255 |
+
)
|
256 |
+
|
257 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
258 |
+
hidden_states = hidden_states.to(query.dtype)
|
259 |
+
|
260 |
+
# linear proj
|
261 |
+
hidden_states = attn.to_out[0](hidden_states)
|
262 |
+
# dropout
|
263 |
+
hidden_states = attn.to_out[1](hidden_states)
|
264 |
+
|
265 |
+
if input_ndim == 4:
|
266 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
267 |
+
|
268 |
+
if attn.residual_connection:
|
269 |
+
hidden_states = hidden_states + residual
|
270 |
+
|
271 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
272 |
+
|
273 |
+
return hidden_states
|
274 |
+
|
275 |
+
|
276 |
+
class IDAttnProcessor2_0(torch.nn.Module):
|
277 |
+
r"""
|
278 |
+
Attention processor for ID-Adapater for PyTorch 2.0.
|
279 |
+
Args:
|
280 |
+
hidden_size (`int`):
|
281 |
+
The hidden size of the attention layer.
|
282 |
+
cross_attention_dim (`int`):
|
283 |
+
The number of channels in the `encoder_hidden_states`.
|
284 |
+
"""
|
285 |
+
|
286 |
+
def __init__(self, hidden_size, cross_attention_dim=None):
|
287 |
+
super().__init__()
|
288 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
289 |
+
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
290 |
+
|
291 |
+
self.id_to_k = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
292 |
+
self.id_to_v = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
293 |
+
|
294 |
+
def __call__(
|
295 |
+
self,
|
296 |
+
attn,
|
297 |
+
hidden_states,
|
298 |
+
encoder_hidden_states=None,
|
299 |
+
attention_mask=None,
|
300 |
+
temb=None,
|
301 |
+
id_embedding=None,
|
302 |
+
id_scale=1.0,
|
303 |
+
):
|
304 |
+
residual = hidden_states
|
305 |
+
|
306 |
+
if attn.spatial_norm is not None:
|
307 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
308 |
+
|
309 |
+
input_ndim = hidden_states.ndim
|
310 |
+
|
311 |
+
if input_ndim == 4:
|
312 |
+
batch_size, channel, height, width = hidden_states.shape
|
313 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
314 |
+
|
315 |
+
batch_size, sequence_length, _ = (
|
316 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
317 |
+
)
|
318 |
+
|
319 |
+
if attention_mask is not None:
|
320 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
321 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
322 |
+
# (batch, heads, source_length, target_length)
|
323 |
+
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
324 |
+
|
325 |
+
if attn.group_norm is not None:
|
326 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
327 |
+
|
328 |
+
query = attn.to_q(hidden_states)
|
329 |
+
|
330 |
+
if encoder_hidden_states is None:
|
331 |
+
encoder_hidden_states = hidden_states
|
332 |
+
elif attn.norm_cross:
|
333 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
334 |
+
|
335 |
+
key = attn.to_k(encoder_hidden_states)
|
336 |
+
value = attn.to_v(encoder_hidden_states)
|
337 |
+
|
338 |
+
inner_dim = key.shape[-1]
|
339 |
+
head_dim = inner_dim // attn.heads
|
340 |
+
|
341 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
342 |
+
|
343 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
344 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
345 |
+
|
346 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
347 |
+
hidden_states = F.scaled_dot_product_attention(
|
348 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
349 |
+
)
|
350 |
+
|
351 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
352 |
+
hidden_states = hidden_states.to(query.dtype)
|
353 |
+
|
354 |
+
# for id embedding
|
355 |
+
if id_embedding is not None:
|
356 |
+
if NUM_ZERO == 0:
|
357 |
+
id_key = self.id_to_k(id_embedding).to(query.dtype)
|
358 |
+
id_value = self.id_to_v(id_embedding).to(query.dtype)
|
359 |
+
else:
|
360 |
+
zero_tensor = torch.zeros(
|
361 |
+
(id_embedding.size(0), NUM_ZERO, id_embedding.size(-1)),
|
362 |
+
dtype=id_embedding.dtype,
|
363 |
+
device=id_embedding.device,
|
364 |
+
)
|
365 |
+
id_key = self.id_to_k(torch.cat((id_embedding, zero_tensor), dim=1)).to(query.dtype)
|
366 |
+
id_value = self.id_to_v(torch.cat((id_embedding, zero_tensor), dim=1)).to(query.dtype)
|
367 |
+
|
368 |
+
id_key = id_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
369 |
+
id_value = id_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
370 |
+
|
371 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
372 |
+
id_hidden_states = F.scaled_dot_product_attention(
|
373 |
+
query, id_key, id_value, attn_mask=None, dropout_p=0.0, is_causal=False
|
374 |
+
)
|
375 |
+
|
376 |
+
id_hidden_states = id_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
377 |
+
id_hidden_states = id_hidden_states.to(query.dtype)
|
378 |
+
|
379 |
+
if not ORTHO and not ORTHO_v2:
|
380 |
+
hidden_states = hidden_states + id_scale * id_hidden_states
|
381 |
+
elif ORTHO_v2:
|
382 |
+
orig_dtype = hidden_states.dtype
|
383 |
+
hidden_states = hidden_states.to(torch.float32)
|
384 |
+
id_hidden_states = id_hidden_states.to(torch.float32)
|
385 |
+
attn_map = query @ id_key.transpose(-2, -1)
|
386 |
+
attn_mean = attn_map.softmax(dim=-1).mean(dim=1)
|
387 |
+
attn_mean = attn_mean[:, :, :5].sum(dim=-1, keepdim=True)
|
388 |
+
projection = (
|
389 |
+
torch.sum((hidden_states * id_hidden_states), dim=-2, keepdim=True)
|
390 |
+
/ torch.sum((hidden_states * hidden_states), dim=-2, keepdim=True)
|
391 |
+
* hidden_states
|
392 |
+
)
|
393 |
+
orthogonal = id_hidden_states + (attn_mean - 1) * projection
|
394 |
+
hidden_states = hidden_states + id_scale * orthogonal
|
395 |
+
hidden_states = hidden_states.to(orig_dtype)
|
396 |
+
else:
|
397 |
+
orig_dtype = hidden_states.dtype
|
398 |
+
hidden_states = hidden_states.to(torch.float32)
|
399 |
+
id_hidden_states = id_hidden_states.to(torch.float32)
|
400 |
+
projection = (
|
401 |
+
torch.sum((hidden_states * id_hidden_states), dim=-2, keepdim=True)
|
402 |
+
/ torch.sum((hidden_states * hidden_states), dim=-2, keepdim=True)
|
403 |
+
* hidden_states
|
404 |
+
)
|
405 |
+
orthogonal = id_hidden_states - projection
|
406 |
+
hidden_states = hidden_states + id_scale * orthogonal
|
407 |
+
hidden_states = hidden_states.to(orig_dtype)
|
408 |
+
|
409 |
+
# linear proj
|
410 |
+
hidden_states = attn.to_out[0](hidden_states)
|
411 |
+
# dropout
|
412 |
+
hidden_states = attn.to_out[1](hidden_states)
|
413 |
+
|
414 |
+
if input_ndim == 4:
|
415 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
416 |
+
|
417 |
+
if attn.residual_connection:
|
418 |
+
hidden_states = hidden_states + residual
|
419 |
+
|
420 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
421 |
+
|
422 |
+
return hidden_states
|